From 3d7ecaccfad6a9d5194f4adbe8fc736c5e907668 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sun, 15 Oct 2023 15:14:54 +0200
Subject: [PATCH 01/35] manual annotations
---
pyproject.toml | 27 ++++---
trio/_core/__init__.py | 2 +-
trio/_core/_tests/test_guest_mode.py | 24 +++---
trio/_core/_tests/test_instrumentation.py | 97 ++++++++++++-----------
trio/_core/_tests/test_ki.py | 76 +++++++++---------
trio/_core/_tests/test_local.py | 51 ++++++------
trio/_core/_tests/test_thread_cache.py | 8 +-
trio/_core/_tests/test_unbounded_queue.py | 32 ++++----
trio/lowlevel.py | 1 +
9 files changed, 175 insertions(+), 143 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 73b110ebac..e12f95c309 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,25 +80,34 @@ disallow_untyped_defs = true
check_untyped_defs = true
disallow_untyped_calls = false
-# files not yet fully typed
+
+# partially typed tests
[[tool.mypy.overrides]]
module = [
-# internal
-"trio/_windows_pipes",
-
-# tests
-"trio/testing/_fake_net",
"trio/_core/_tests/test_guest_mode",
-"trio/_core/_tests/test_instrumentation",
"trio/_core/_tests/test_ki",
-"trio/_core/_tests/test_local",
"trio/_core/_tests/test_mock_clock",
"trio/_core/_tests/test_multierror",
"trio/_core/_tests/test_multierror_scripts/ipython_custom_exc",
"trio/_core/_tests/test_multierror_scripts/simple_excepthook",
"trio/_core/_tests/test_parking_lot",
"trio/_core/_tests/test_thread_cache",
-"trio/_core/_tests/test_unbounded_queue",
+]
+check_untyped_defs = true
+disallow_any_decorated = false
+disallow_any_generics = false
+disallow_any_unimported = false
+disallow_incomplete_defs = false
+disallow_untyped_defs = false
+
+# files not yet fully typed
+[[tool.mypy.overrides]]
+module = [
+# internal
+"trio/_windows_pipes",
+
+# tests
+"trio/testing/_fake_net",
"trio/_tests/test_exports",
"trio/_tests/test_file_io",
"trio/_tests/test_highlevel_generic",
diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py
index b9bd0d8cc4..71f5f17eb2 100644
--- a/trio/_core/__init__.py
+++ b/trio/_core/__init__.py
@@ -18,7 +18,7 @@
WouldBlock,
)
from ._ki import currently_ki_protected, disable_ki_protection, enable_ki_protection
-from ._local import RunVar
+from ._local import RunVar, RunVarToken
from ._mock_clock import MockClock
from ._parking_lot import ParkingLot, ParkingLotStatistics
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index 80180be805..21980b9b72 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import asyncio
import contextvars
import queue
@@ -10,8 +12,10 @@
import warnings
from functools import partial
from math import inf
+from typing import Callable
import pytest
+from outcome import Outcome
import trio
import trio.testing
@@ -27,11 +31,11 @@
# - final result is returned
# - any unhandled exceptions cause an immediate crash
def trivial_guest_run(trio_fn, *, in_host_after_start=None, **start_guest_run_kwargs):
- todo = queue.Queue()
+ todo: queue.Queue[tuple[str, Outcome]] = queue.Queue()
host_thread = threading.current_thread()
- def run_sync_soon_threadsafe(fn):
+ def run_sync_soon_threadsafe(fn: Callable):
nonlocal todo
if host_thread is threading.current_thread(): # pragma: no cover
crash = partial(
@@ -40,7 +44,7 @@ def run_sync_soon_threadsafe(fn):
todo.put(("run", crash))
todo.put(("run", fn))
- def run_sync_soon_not_threadsafe(fn):
+ def run_sync_soon_not_threadsafe(fn: Callable):
nonlocal todo
if host_thread is not threading.current_thread(): # pragma: no cover
crash = partial(
@@ -49,7 +53,7 @@ def run_sync_soon_not_threadsafe(fn):
todo.put(("run", crash))
todo.put(("run", fn))
- def done_callback(outcome):
+ def done_callback(outcome: Outcome):
nonlocal todo
todo.put(("unwrap", outcome))
@@ -322,18 +326,18 @@ async def get_woken_by_host_deadline(watb_cscope):
# 'sit_in_wait_all_tasks_blocked', we want the test to
# actually end. So in after_io_wait we schedule a second host
# call to tear things down.
- class InstrumentHelper:
- def __init__(self):
+ class InstrumentHelper(trio._abc.Instrument):
+ def __init__(self) -> None:
self.primed = False
- def before_io_wait(self, timeout):
+ def before_io_wait(self, timeout: float) -> None:
print(f"before_io_wait({timeout})")
if timeout == 9999: # pragma: no branch
assert not self.primed
in_host(lambda: set_deadline(cscope, 1e9))
self.primed = True
- def after_io_wait(self, timeout):
+ def after_io_wait(self, timeout: float) -> None:
if self.primed: # pragma: no branch
print("instrument triggered")
in_host(lambda: cscope.cancel())
@@ -429,8 +433,8 @@ def test_guest_mode_on_asyncio():
async def trio_main():
print("trio_main!")
- to_trio, from_aio = trio.open_memory_channel(float("inf"))
- from_trio = asyncio.Queue()
+ to_trio, from_aio = trio.open_memory_channel[int](float("inf"))
+ from_trio = asyncio.Queue[int]()
aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio))
diff --git a/trio/_core/_tests/test_instrumentation.py b/trio/_core/_tests/test_instrumentation.py
index 498a3eb272..8b103dbfdc 100644
--- a/trio/_core/_tests/test_instrumentation.py
+++ b/trio/_core/_tests/test_instrumentation.py
@@ -1,32 +1,37 @@
+from __future__ import annotations
+
+from typing import Container, Iterable, NoReturn
+
import attr
import pytest
from ... import _abc, _core
+from ...lowlevel import Task
from .tutil import check_sequence_matches
@attr.s(eq=False, hash=False)
-class TaskRecorder:
- record = attr.ib(factory=list)
+class TaskRecorder(_abc.Instrument):
+ record: list[tuple[str, Task | None]] = attr.ib(factory=list)
- def before_run(self):
- self.record.append(("before_run",))
+ def before_run(self) -> None:
+ self.record.append(("before_run", None))
- def task_scheduled(self, task):
+ def task_scheduled(self, task: Task) -> None:
self.record.append(("schedule", task))
- def before_task_step(self, task):
+ def before_task_step(self, task: Task) -> None:
assert task is _core.current_task()
self.record.append(("before", task))
- def after_task_step(self, task):
+ def after_task_step(self, task: Task) -> None:
assert task is _core.current_task()
self.record.append(("after", task))
- def after_run(self):
- self.record.append(("after_run",))
+ def after_run(self) -> None:
+ self.record.append(("after_run", None))
- def filter_tasks(self, tasks):
+ def filter_tasks(self, tasks: Container[Task]) -> Iterable[tuple[str, Task | None]]:
for item in self.record:
if item[0] in ("schedule", "before", "after") and item[1] in tasks:
yield item
@@ -34,7 +39,7 @@ def filter_tasks(self, tasks):
yield item
-def test_instruments(recwarn):
+def test_instruments(recwarn: object) -> None:
r1 = TaskRecorder()
r2 = TaskRecorder()
r3 = TaskRecorder()
@@ -44,7 +49,7 @@ def test_instruments(recwarn):
# We use a child task for this, because the main task does some extra
# bookkeeping stuff that can leak into the instrument results, and we
# don't want to deal with it.
- async def task_fn():
+ async def task_fn() -> None:
nonlocal task
task = _core.current_task()
@@ -60,7 +65,7 @@ async def task_fn():
for _ in range(1):
await _core.checkpoint()
- async def main():
+ async def main() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(task_fn)
@@ -75,21 +80,22 @@ async def main():
+ [("before", task), ("after", task), ("after_run",)]
)
assert r1.record == r2.record + r3.record
+ assert task is not None
assert list(r1.filter_tasks([task])) == expected
-def test_instruments_interleave():
+def test_instruments_interleave() -> None:
tasks = {}
- async def two_step1():
+ async def two_step1() -> None:
tasks["t1"] = _core.current_task()
await _core.checkpoint()
- async def two_step2():
+ async def two_step2() -> None:
tasks["t2"] = _core.current_task()
await _core.checkpoint()
- async def main():
+ async def main() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(two_step1)
nursery.start_soon(two_step2)
@@ -121,46 +127,46 @@ async def main():
check_sequence_matches(list(r.filter_tasks(tasks.values())), expected)
-def test_null_instrument():
+def test_null_instrument() -> None:
# undefined instrument methods are skipped
- class NullInstrument:
- def something_unrelated(self):
+ class NullInstrument(_abc.Instrument):
+ def something_unrelated(self) -> None:
pass # pragma: no cover
- async def main():
+ async def main() -> None:
await _core.checkpoint()
_core.run(main, instruments=[NullInstrument()])
-def test_instrument_before_after_run():
+def test_instrument_before_after_run() -> None:
record = []
- class BeforeAfterRun:
- def before_run(self):
+ class BeforeAfterRun(_abc.Instrument):
+ def before_run(self) -> None:
record.append("before_run")
- def after_run(self):
+ def after_run(self) -> None:
record.append("after_run")
- async def main():
+ async def main() -> None:
pass
_core.run(main, instruments=[BeforeAfterRun()])
assert record == ["before_run", "after_run"]
-def test_instrument_task_spawn_exit():
+def test_instrument_task_spawn_exit() -> None:
record = []
- class SpawnExitRecorder:
- def task_spawned(self, task):
+ class SpawnExitRecorder(_abc.Instrument):
+ def task_spawned(self, task: Task) -> None:
record.append(("spawned", task))
- def task_exited(self, task):
+ def task_exited(self, task: Task) -> None:
record.append(("exited", task))
- async def main():
+ async def main() -> Task:
return _core.current_task()
main_task = _core.run(main, instruments=[SpawnExitRecorder()])
@@ -170,20 +176,20 @@ async def main():
# This test also tests having a crash before the initial task is even spawned,
# which is very difficult to handle.
-def test_instruments_crash(caplog):
+def test_instruments_crash(caplog: pytest.LogCaptureFixture) -> None:
record = []
- class BrokenInstrument:
- def task_scheduled(self, task):
+ class BrokenInstrument(_abc.Instrument):
+ def task_scheduled(self, task: Task) -> NoReturn:
record.append("scheduled")
raise ValueError("oops")
- def close(self):
+ def close(self) -> None:
# Shouldn't be called -- tests that the instrument disabling logic
# works right.
record.append("closed") # pragma: no cover
- async def main():
+ async def main() -> Task:
record.append("main ran")
return _core.current_task()
@@ -195,24 +201,25 @@ async def main():
assert ("after", main_task) in r.record
assert ("after_run",) in r.record
# And we got a log message
+ assert caplog.records[0].exc_info is not None
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info
assert exc_type is ValueError
assert str(exc_value) == "oops"
assert "Instrument has been disabled" in caplog.records[0].message
-def test_instruments_monkeypatch():
+def test_instruments_monkeypatch() -> None:
class NullInstrument(_abc.Instrument):
pass
instrument = NullInstrument()
- async def main():
- record = []
+ async def main() -> None:
+ record: list[Task] = []
# Changing the set of hooks implemented by an instrument after
# it's installed doesn't make them start being called right away
- instrument.before_task_step = record.append
+ instrument.before_task_step = record.append # type: ignore[assignment]
await _core.checkpoint()
await _core.checkpoint()
assert len(record) == 0
@@ -233,16 +240,16 @@ async def main():
_core.run(main, instruments=[instrument])
-def test_instrument_that_raises_on_getattr():
- class EvilInstrument:
- def task_exited(self, task):
+def test_instrument_that_raises_on_getattr() -> None:
+ class EvilInstrument(_abc.Instrument):
+ def task_exited(self, task: Task) -> NoReturn:
assert False # pragma: no cover
@property
- def after_run(self):
+ def after_run(self) -> NoReturn:
raise ValueError("oops")
- async def main():
+ async def main() -> None:
with pytest.raises(ValueError):
_core.add_instrument(EvilInstrument())
diff --git a/trio/_core/_tests/test_ki.py b/trio/_core/_tests/test_ki.py
index dc9f2f51e5..0fda688194 100644
--- a/trio/_core/_tests/test_ki.py
+++ b/trio/_core/_tests/test_ki.py
@@ -15,6 +15,7 @@
async_generator = yield_ = None
from ... import _core
+from ..._abc import Instrument
from ..._timeouts import sleep
from ..._util import signal_raise
from ...testing import wait_all_tasks_blocked
@@ -283,33 +284,33 @@ async def raiser(name, record):
# simulated control-C during raiser, which is *unprotected*
print("check 1")
- record = set()
+ record_set: set[str] = set()
async def check_unprotected_kill():
async with _core.open_nursery() as nursery:
- nursery.start_soon(sleeper, "s1", record)
- nursery.start_soon(sleeper, "s2", record)
- nursery.start_soon(raiser, "r1", record)
+ nursery.start_soon(sleeper, "s1", record_set)
+ nursery.start_soon(sleeper, "s2", record_set)
+ nursery.start_soon(raiser, "r1", record_set)
with pytest.raises(KeyboardInterrupt):
_core.run(check_unprotected_kill)
- assert record == {"s1 ok", "s2 ok", "r1 raise ok"}
+ assert record_set == {"s1 ok", "s2 ok", "r1 raise ok"}
# simulated control-C during raiser, which is *protected*, so the KI gets
# delivered to the main task instead
print("check 2")
- record = set()
+ record_set = set()
async def check_protected_kill():
async with _core.open_nursery() as nursery:
- nursery.start_soon(sleeper, "s1", record)
- nursery.start_soon(sleeper, "s2", record)
- nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record)
+ nursery.start_soon(sleeper, "s1", record_set)
+ nursery.start_soon(sleeper, "s2", record_set)
+ nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record_set)
# __aexit__ blocks, and then receives the KI
with pytest.raises(KeyboardInterrupt):
_core.run(check_protected_kill)
- assert record == {"s1 ok", "s2 ok", "r1 cancel ok"}
+ assert record_set == {"s1 ok", "s2 ok", "r1 cancel ok"}
# kill at last moment still raises (run_sync_soon until it raises an
# error, then kill)
@@ -335,33 +336,33 @@ def kill_during_shutdown():
# KI arrives very early, before main is even spawned
print("check 4")
- class InstrumentOfDeath:
- def before_run(self):
+ class InstrumentOfDeath(Instrument):
+ def before_run(self) -> None:
ki_self()
- async def main():
+ async def main_1():
await _core.checkpoint()
with pytest.raises(KeyboardInterrupt):
- _core.run(main, instruments=[InstrumentOfDeath()])
+ _core.run(main_1, instruments=[InstrumentOfDeath()])
# checkpoint_if_cancelled notices pending KI
print("check 5")
@_core.enable_ki_protection
- async def main():
+ async def main_2():
assert _core.currently_ki_protected()
ki_self()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint_if_cancelled()
- _core.run(main)
+ _core.run(main_2)
# KI arrives while main task is not abortable, b/c already scheduled
print("check 6")
@_core.enable_ki_protection
- async def main():
+ async def main_3():
assert _core.currently_ki_protected()
ki_self()
await _core.cancel_shielded_checkpoint()
@@ -370,13 +371,13 @@ async def main():
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
- _core.run(main)
+ _core.run(main_3)
# KI arrives while main task is not abortable, b/c refuses to be aborted
print("check 7")
@_core.enable_ki_protection
- async def main():
+ async def main_4():
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
@@ -389,13 +390,13 @@ def abort(_: RaiseCancelT) -> Abort:
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
- _core.run(main)
+ _core.run(main_4)
# KI delivered via slow abort
print("check 8")
@_core.enable_ki_protection
- async def main():
+ async def main_5():
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
@@ -409,7 +410,7 @@ def abort(raise_cancel: RaiseCancelT) -> Abort:
assert await _core.wait_task_rescheduled(abort)
await _core.checkpoint()
- _core.run(main)
+ _core.run(main_5)
# KI arrives just before main task exits, so the run_sync_soon machinery
# is still functioning and will accept the callback to deliver the KI, but
@@ -418,42 +419,42 @@ def abort(raise_cancel: RaiseCancelT) -> Abort:
print("check 9")
@_core.enable_ki_protection
- async def main():
+ async def main_6():
ki_self()
with pytest.raises(KeyboardInterrupt):
- _core.run(main)
+ _core.run(main_6)
print("check 10")
# KI in unprotected code, with
# restrict_keyboard_interrupt_to_checkpoints=True
- record = []
+ record_list = []
- async def main():
+ async def main_7():
# We're not KI protected...
assert not _core.currently_ki_protected()
ki_self()
# ...but even after the KI, we keep running uninterrupted...
- record.append("ok")
+ record_list.append("ok")
# ...until we hit a checkpoint:
with pytest.raises(KeyboardInterrupt):
await sleep(10)
- _core.run(main, restrict_keyboard_interrupt_to_checkpoints=True)
- assert record == ["ok"]
- record = []
+ _core.run(main_7, restrict_keyboard_interrupt_to_checkpoints=True)
+ assert record_list == ["ok"]
+ record_list = []
# Exact same code raises KI early if we leave off the argument, doesn't
# even reach the record.append call:
with pytest.raises(KeyboardInterrupt):
- _core.run(main)
- assert record == []
+ _core.run(main_7)
+ assert record_list == []
# KI arrives while main task is inside a cancelled cancellation scope
# the KeyboardInterrupt should take priority
print("check 11")
@_core.enable_ki_protection
- async def main():
+ async def main_8():
assert _core.currently_ki_protected()
with _core.CancelScope() as cancel_scope:
cancel_scope.cancel()
@@ -465,7 +466,7 @@ async def main():
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
- _core.run(main)
+ _core.run(main_8)
def test_ki_is_good_neighbor():
@@ -488,16 +489,17 @@ async def main():
# Regression test for #461
+# don't know if _active not being visible is a problem
def test_ki_with_broken_threads():
thread = threading.main_thread()
# scary!
- original = threading._active[thread.ident]
+ original = threading._active[thread.ident] # type: ignore[attr-defined]
# put this in a try finally so we don't have a chance of cascading a
# breakage down to everything else
try:
- del threading._active[thread.ident]
+ del threading._active[thread.ident] # type: ignore[attr-defined]
@_core.enable_ki_protection
async def inner():
@@ -505,4 +507,4 @@ async def inner():
_core.run(inner)
finally:
- threading._active[thread.ident] = original
+ threading._active[thread.ident] = original # type: ignore[attr-defined]
diff --git a/trio/_core/_tests/test_local.py b/trio/_core/_tests/test_local.py
index d36be0479e..5fdf54b13c 100644
--- a/trio/_core/_tests/test_local.py
+++ b/trio/_core/_tests/test_local.py
@@ -1,16 +1,19 @@
import pytest
+from trio import run
+from trio.lowlevel import RunVar, RunVarToken
+
from ... import _core
# scary runvar tests
-def test_runvar_smoketest():
- t1 = _core.RunVar("test1")
- t2 = _core.RunVar("test2", default="catfish")
+def test_runvar_smoketest() -> None:
+ t1 = RunVar[str]("test1")
+ t2 = RunVar[str]("test2", default="catfish")
assert repr(t1) == ""
- async def first_check():
+ async def first_check() -> None:
with pytest.raises(LookupError):
t1.get()
@@ -23,28 +26,28 @@ async def first_check():
assert t2.get() == "goldfish"
assert t2.get(default="tuna") == "goldfish"
- async def second_check():
+ async def second_check() -> None:
with pytest.raises(LookupError):
t1.get()
assert t2.get() == "catfish"
- _core.run(first_check)
- _core.run(second_check)
+ run(first_check)
+ run(second_check)
-def test_runvar_resetting():
- t1 = _core.RunVar("test1")
- t2 = _core.RunVar("test2", default="dogfish")
- t3 = _core.RunVar("test3")
+def test_runvar_resetting() -> None:
+ t1 = RunVar[str]("test1")
+ t2 = RunVar[str]("test2", default="dogfish")
+ t3 = RunVar[str]("test3")
- async def reset_check():
+ async def reset_check() -> None:
token = t1.set("moonfish")
assert t1.get() == "moonfish"
t1.reset(token)
with pytest.raises(TypeError):
- t1.reset(None)
+ t1.reset(None) # type: ignore[arg-type]
with pytest.raises(LookupError):
t1.get()
@@ -63,18 +66,18 @@ async def reset_check():
with pytest.raises(ValueError):
t1.reset(token3)
- _core.run(reset_check)
+ run(reset_check)
-def test_runvar_sync():
- t1 = _core.RunVar("test1")
+def test_runvar_sync() -> None:
+ t1 = RunVar[str]("test1")
- async def sync_check():
- async def task1():
+ async def sync_check() -> None:
+ async def task1() -> None:
t1.set("plaice")
assert t1.get() == "plaice"
- async def task2(tok):
+ async def task2(tok: str) -> None:
t1.reset(token)
with pytest.raises(LookupError):
@@ -94,11 +97,11 @@ async def task2(tok):
await _core.wait_all_tasks_blocked()
assert t1.get() == "haddock"
- _core.run(sync_check)
+ run(sync_check)
-def test_accessing_runvar_outside_run_call_fails():
- t1 = _core.RunVar("test1")
+def test_accessing_runvar_outside_run_call_fails() -> None:
+ t1 = RunVar[str]("test1")
with pytest.raises(RuntimeError):
t1.set("asdf")
@@ -106,10 +109,10 @@ def test_accessing_runvar_outside_run_call_fails():
with pytest.raises(RuntimeError):
t1.get()
- async def get_token():
+ async def get_token() -> RunVarToken[str]:
return t1.set("ok")
- token = _core.run(get_token)
+ token = run(get_token)
with pytest.raises(RuntimeError):
t1.reset(token)
diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py
index de78443f4e..3cd79ecd8a 100644
--- a/trio/_core/_tests/test_thread_cache.py
+++ b/trio/_core/_tests/test_thread_cache.py
@@ -1,9 +1,13 @@
+from __future__ import annotations
+
import threading
import time
from contextlib import contextmanager
from queue import Queue
+from typing import NoReturn
import pytest
+from outcome import Outcome
from .. import _thread_cache
from .._thread_cache import ThreadCache, start_thread_soon
@@ -11,9 +15,9 @@
def test_thread_cache_basics():
- q = Queue()
+ q = Queue[Outcome]()
- def fn():
+ def fn() -> NoReturn:
raise RuntimeError("hi")
def deliver(outcome):
diff --git a/trio/_core/_tests/test_unbounded_queue.py b/trio/_core/_tests/test_unbounded_queue.py
index cffeed1618..33eb41a5b3 100644
--- a/trio/_core/_tests/test_unbounded_queue.py
+++ b/trio/_core/_tests/test_unbounded_queue.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import itertools
import pytest
@@ -10,8 +12,8 @@
)
-async def test_UnboundedQueue_basic():
- q = _core.UnboundedQueue()
+async def test_UnboundedQueue_basic() -> None:
+ q: _core.UnboundedQueue[str | int | None] = _core.UnboundedQueue()
q.put_nowait("hi")
assert await q.get_batch() == ["hi"]
with pytest.raises(_core.WouldBlock):
@@ -35,17 +37,17 @@ async def test_UnboundedQueue_basic():
repr(q)
-async def test_UnboundedQueue_blocking():
+async def test_UnboundedQueue_blocking() -> None:
record = []
- q = _core.UnboundedQueue()
+ q = _core.UnboundedQueue[int]()
- async def get_batch_consumer():
+ async def get_batch_consumer() -> None:
while True:
batch = await q.get_batch()
assert batch
record.append(batch)
- async def aiter_consumer():
+ async def aiter_consumer() -> None:
async for batch in q:
assert batch
record.append(batch)
@@ -67,8 +69,8 @@ async def aiter_consumer():
nursery.cancel_scope.cancel()
-async def test_UnboundedQueue_fairness():
- q = _core.UnboundedQueue()
+async def test_UnboundedQueue_fairness() -> None:
+ q = _core.UnboundedQueue[int]()
# If there's no-one else around, we can put stuff in and take it out
# again, no problem
@@ -78,7 +80,7 @@ async def test_UnboundedQueue_fairness():
result = None
- async def get_batch(q):
+ async def get_batch(q: _core.UnboundedQueue[int]) -> None:
nonlocal result
result = await q.get_batch()
@@ -95,7 +97,7 @@ async def get_batch(q):
# If two tasks are trying to read, they alternate
record = []
- async def reader(name):
+ async def reader(name: str) -> None:
while True:
record.append((name, await q.get_batch()))
@@ -114,8 +116,8 @@ async def reader(name):
assert record == list(zip(itertools.cycle("ab"), [[i] for i in range(20)]))
-async def test_UnboundedQueue_trivial_yields():
- q = _core.UnboundedQueue()
+async def test_UnboundedQueue_trivial_yields() -> None:
+ q = _core.UnboundedQueue[None]()
q.put_nowait(None)
with assert_checkpoints():
@@ -127,17 +129,17 @@ async def test_UnboundedQueue_trivial_yields():
break
-async def test_UnboundedQueue_no_spurious_wakeups():
+async def test_UnboundedQueue_no_spurious_wakeups() -> None:
# If we have two tasks waiting, and put two items into the queue... then
# only one task wakes up
record = []
- async def getter(q, i):
+ async def getter(q: _core.UnboundedQueue[int], i: int) -> None:
got = await q.get_batch()
record.append((i, got))
async with _core.open_nursery() as nursery:
- q = _core.UnboundedQueue()
+ q = _core.UnboundedQueue[int]()
nursery.start_soon(getter, q, 1)
await wait_all_tasks_blocked()
nursery.start_soon(getter, q, 2)
diff --git a/trio/lowlevel.py b/trio/lowlevel.py
index 25e64975e2..964dabb556 100644
--- a/trio/lowlevel.py
+++ b/trio/lowlevel.py
@@ -15,6 +15,7 @@
RaiseCancelT as RaiseCancelT,
RunStatistics as RunStatistics,
RunVar as RunVar,
+ RunVarToken as RunVarToken,
Task as Task,
TrioToken as TrioToken,
UnboundedQueue as UnboundedQueue,
From 8eb11900a7e4ba59e668790f7c161903ef352ea2 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sun, 15 Oct 2023 15:22:32 +0200
Subject: [PATCH 02/35] pyannotate --safe
---
trio/_tests/test_exports.py | 6 +-
trio/_tests/test_file_io.py | 38 +++---
trio/_tests/test_highlevel_generic.py | 16 +--
trio/_tests/test_highlevel_open_tcp_stream.py | 2 +-
.../_tests/test_highlevel_open_unix_stream.py | 8 +-
trio/_tests/test_highlevel_serve_listeners.py | 20 +--
trio/_tests/test_highlevel_socket.py | 2 +-
trio/_tests/test_highlevel_ssl_helpers.py | 8 +-
trio/_tests/test_path.py | 48 +++----
trio/_tests/test_scheduler_determinism.py | 6 +-
trio/_tests/test_socket.py | 2 +-
trio/_tests/test_ssl.py | 126 +++++++++---------
trio/_tests/test_subprocess.py | 56 ++++----
trio/_tests/test_sync.py | 62 ++++-----
trio/_tests/test_testing.py | 90 ++++++-------
trio/_tests/test_threads.py | 116 ++++++++--------
trio/_tests/test_timeouts.py | 18 +--
trio/_tests/test_tracing.py | 4 +-
trio/_tests/test_unix_pipes.py | 36 ++---
trio/_tests/test_util.py | 26 ++--
trio/_tests/test_wait_for_object.py | 10 +-
trio/_tests/test_windows_pipes.py | 18 +--
trio/_tests/tools/test_gen_exports.py | 6 +-
23 files changed, 363 insertions(+), 361 deletions(-)
diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py
index c3d8a03b63..7a016e86d3 100644
--- a/trio/_tests/test_exports.py
+++ b/trio/_tests/test_exports.py
@@ -34,7 +34,7 @@
Protocol_ext = Protocol # type: ignore[assignment]
-def _ensure_mypy_cache_updated():
+def _ensure_mypy_cache_updated() -> None:
# This pollutes the `empty` dir. Should this be changed?
try:
from mypy.api import run
@@ -59,7 +59,7 @@ def _ensure_mypy_cache_updated():
mypy_cache_updated = True
-def test_core_is_properly_reexported():
+def test_core_is_properly_reexported() -> None:
# Each export from _core should be re-exported by exactly one of these
# three modules:
sources = [trio, trio.lowlevel, trio.testing]
@@ -126,7 +126,7 @@ def iter_modules(
# https://github.com/pypa/setuptools/issues/3274
"ignore:module 'sre_constants' is deprecated:DeprecationWarning",
)
-def test_static_tool_sees_all_symbols(tool, modname, tmpdir):
+def test_static_tool_sees_all_symbols(tool, modname, tmpdir) -> None:
module = importlib.import_module(modname)
def no_underscores(symbols):
diff --git a/trio/_tests/test_file_io.py b/trio/_tests/test_file_io.py
index bae426cf48..863ebe81b0 100644
--- a/trio/_tests/test_file_io.py
+++ b/trio/_tests/test_file_io.py
@@ -28,17 +28,17 @@ def async_file(wrapped):
return trio.wrap_file(wrapped)
-def test_wrap_invalid():
+def test_wrap_invalid() -> None:
with pytest.raises(TypeError):
trio.wrap_file("")
-def test_wrap_non_iobase():
+def test_wrap_non_iobase() -> None:
class FakeFile:
- def close(self): # pragma: no cover
+ def close(self) -> None: # pragma: no cover
pass
- def write(self): # pragma: no cover
+ def write(self) -> None: # pragma: no cover
pass
wrapped = FakeFile()
@@ -53,11 +53,11 @@ def write(self): # pragma: no cover
trio.wrap_file(FakeFile())
-def test_wrapped_property(async_file, wrapped):
+def test_wrapped_property(async_file, wrapped) -> None:
assert async_file.wrapped is wrapped
-def test_dir_matches_wrapped(async_file, wrapped):
+def test_dir_matches_wrapped(async_file, wrapped) -> None:
attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS)
# all supported attrs in wrapped should be available in async_file
@@ -68,9 +68,9 @@ def test_dir_matches_wrapped(async_file, wrapped):
)
-def test_unsupported_not_forwarded():
+def test_unsupported_not_forwarded() -> None:
class FakeFile(io.RawIOBase):
- def unsupported_attr(self): # pragma: no cover
+ def unsupported_attr(self) -> None: # pragma: no cover
pass
async_file = trio.wrap_file(FakeFile())
@@ -121,7 +121,7 @@ def test_type_stubs_match_lists() -> None:
assert found == expected
-def test_sync_attrs_forwarded(async_file, wrapped):
+def test_sync_attrs_forwarded(async_file, wrapped) -> None:
for attr_name in _FILE_SYNC_ATTRS:
if attr_name not in dir(async_file):
continue
@@ -129,7 +129,7 @@ def test_sync_attrs_forwarded(async_file, wrapped):
assert getattr(async_file, attr_name) is getattr(wrapped, attr_name)
-def test_sync_attrs_match_wrapper(async_file, wrapped):
+def test_sync_attrs_match_wrapper(async_file, wrapped) -> None:
for attr_name in _FILE_SYNC_ATTRS:
if attr_name in dir(async_file):
continue
@@ -141,7 +141,7 @@ def test_sync_attrs_match_wrapper(async_file, wrapped):
getattr(wrapped, attr_name)
-def test_async_methods_generated_once(async_file):
+def test_async_methods_generated_once(async_file) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name not in dir(async_file):
continue
@@ -149,7 +149,7 @@ def test_async_methods_generated_once(async_file):
assert getattr(async_file, meth_name) is getattr(async_file, meth_name)
-def test_async_methods_signature(async_file):
+def test_async_methods_signature(async_file) -> None:
# use read as a representative of all async methods
assert async_file.read.__name__ == "read"
assert async_file.read.__qualname__ == "AsyncIOWrapper.read"
@@ -157,7 +157,7 @@ def test_async_methods_signature(async_file):
assert "io.StringIO.read" in async_file.read.__doc__
-async def test_async_methods_wrap(async_file, wrapped):
+async def test_async_methods_wrap(async_file, wrapped) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name not in dir(async_file):
continue
@@ -175,7 +175,7 @@ async def test_async_methods_wrap(async_file, wrapped):
wrapped.reset_mock()
-async def test_async_methods_match_wrapper(async_file, wrapped):
+async def test_async_methods_match_wrapper(async_file, wrapped) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name in dir(async_file):
continue
@@ -187,7 +187,7 @@ async def test_async_methods_match_wrapper(async_file, wrapped):
getattr(wrapped, meth_name)
-async def test_open(path):
+async def test_open(path) -> None:
f = await trio.open_file(path, "w")
assert isinstance(f, AsyncIOWrapper)
@@ -195,7 +195,7 @@ async def test_open(path):
await f.aclose()
-async def test_open_context_manager(path):
+async def test_open_context_manager(path) -> None:
async with await trio.open_file(path, "w") as f:
assert isinstance(f, AsyncIOWrapper)
assert not f.closed
@@ -203,7 +203,7 @@ async def test_open_context_manager(path):
assert f.closed
-async def test_async_iter():
+async def test_async_iter() -> None:
async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar"))
expected = list(async_file.wrapped)
result = []
@@ -215,7 +215,7 @@ async def test_async_iter():
assert result == expected
-async def test_aclose_cancelled(path):
+async def test_aclose_cancelled(path) -> None:
with _core.CancelScope() as cscope:
f = await trio.open_file(path, "w")
cscope.cancel()
@@ -229,7 +229,7 @@ async def test_aclose_cancelled(path):
assert f.closed
-async def test_detach_rewraps_asynciobase():
+async def test_detach_rewraps_asynciobase() -> None:
raw = io.BytesIO()
buffered = io.BufferedReader(raw)
diff --git a/trio/_tests/test_highlevel_generic.py b/trio/_tests/test_highlevel_generic.py
index 38bcedee25..64c5697184 100644
--- a/trio/_tests/test_highlevel_generic.py
+++ b/trio/_tests/test_highlevel_generic.py
@@ -9,13 +9,13 @@
class RecordSendStream(SendStream):
record = attr.ib(factory=list)
- async def send_all(self, data):
+ async def send_all(self, data) -> None:
self.record.append(("send_all", data))
- async def wait_send_all_might_not_block(self):
+ async def wait_send_all_might_not_block(self) -> None:
self.record.append("wait_send_all_might_not_block")
- async def aclose(self):
+ async def aclose(self) -> None:
self.record.append("aclose")
@@ -23,14 +23,14 @@ async def aclose(self):
class RecordReceiveStream(ReceiveStream):
record = attr.ib(factory=list)
- async def receive_some(self, max_bytes=None):
+ async def receive_some(self, max_bytes=None) -> None:
self.record.append(("receive_some", max_bytes))
- async def aclose(self):
+ async def aclose(self) -> None:
self.record.append("aclose")
-async def test_StapledStream():
+async def test_StapledStream() -> None:
send_stream = RecordSendStream()
receive_stream = RecordReceiveStream()
stapled = StapledStream(send_stream, receive_stream)
@@ -50,7 +50,7 @@ async def test_StapledStream():
assert send_stream.record == ["aclose"]
send_stream.record.clear()
- async def fake_send_eof():
+ async def fake_send_eof() -> None:
send_stream.record.append("send_eof")
send_stream.send_eof = fake_send_eof
@@ -70,7 +70,7 @@ async def fake_send_eof():
assert send_stream.record == ["aclose"]
-async def test_StapledStream_with_erroring_close():
+async def test_StapledStream_with_erroring_close() -> None:
# Make sure that if one of the aclose methods errors out, then the other
# one still gets called.
class BrokenSendStream(RecordSendStream):
diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py
index f875bfa019..7917e66c50 100644
--- a/trio/_tests/test_highlevel_open_tcp_stream.py
+++ b/trio/_tests/test_highlevel_open_tcp_stream.py
@@ -244,7 +244,7 @@ def __init__(
port: int,
ip_list: Sequence[tuple[str, float, str]],
supported_families: set[AddressFamily],
- ):
+ ) -> None:
# ip_list have to be unique
ip_order = [ip for (ip, _, _) in ip_list]
assert len(set(ip_order)) == len(ip_list)
diff --git a/trio/_tests/test_highlevel_open_unix_stream.py b/trio/_tests/test_highlevel_open_unix_stream.py
index 64a15f9e9d..cf32b8c0fc 100644
--- a/trio/_tests/test_highlevel_open_unix_stream.py
+++ b/trio/_tests/test_highlevel_open_unix_stream.py
@@ -15,7 +15,7 @@ def test_close_on_error():
class CloseMe:
closed = False
- def close(self):
+ def close(self) -> None:
self.closed = True
with close_on_error(CloseMe()) as c:
@@ -29,12 +29,12 @@ def close(self):
@pytest.mark.parametrize("filename", [4, 4.5])
-async def test_open_with_bad_filename_type(filename):
+async def test_open_with_bad_filename_type(filename) -> None:
with pytest.raises(TypeError):
await open_unix_socket(filename)
-async def test_open_bad_socket():
+async def test_open_bad_socket() -> None:
# mktemp is marked as insecure, but that's okay, we don't want the file to
# exist
name = tempfile.mktemp()
@@ -42,7 +42,7 @@ async def test_open_bad_socket():
await open_unix_socket(name)
-async def test_open_unix_socket():
+async def test_open_unix_socket() -> None:
for name_type in [Path, str]:
name = tempfile.mktemp()
serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py
index 65804f4222..beac59f86d 100644
--- a/trio/_tests/test_highlevel_serve_listeners.py
+++ b/trio/_tests/test_highlevel_serve_listeners.py
@@ -32,33 +32,33 @@ async def accept(self):
self.accepted_streams.append(stream)
return stream
- async def aclose(self):
+ async def aclose(self) -> None:
self.closed = True
await trio.lowlevel.checkpoint()
-async def test_serve_listeners_basic():
+async def test_serve_listeners_basic() -> None:
listeners = [MemoryListener(), MemoryListener()]
record = []
- def close_hook():
+ def close_hook() -> None:
# Make sure this is a forceful close
assert trio.current_effective_deadline() == float("-inf")
record.append("closed")
- async def handler(stream):
+ async def handler(stream) -> None:
await stream.send_all(b"123")
assert await stream.receive_some(10) == b"456"
stream.send_stream.close_hook = close_hook
stream.receive_stream.close_hook = close_hook
- async def client(listener):
+ async def client(listener) -> None:
s = await listener.connect()
assert await s.receive_some(10) == b"123"
await s.send_all(b"456")
- async def do_tests(parent_nursery):
+ async def do_tests(parent_nursery) -> None:
async with trio.open_nursery() as nursery:
for listener in listeners:
for _ in range(3):
@@ -82,7 +82,7 @@ async def do_tests(parent_nursery):
assert listener.closed
-async def test_serve_listeners_accept_unrecognized_error():
+async def test_serve_listeners_accept_unrecognized_error() -> None:
for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]:
listener = MemoryListener()
@@ -96,7 +96,7 @@ async def raise_error():
assert excinfo.value is error
-async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog):
+async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog) -> None:
listener = MemoryListener()
async def raise_EMFILE():
@@ -115,10 +115,10 @@ async def raise_EMFILE():
assert record.exc_info[1].errno == errno.EMFILE
-async def test_serve_listeners_connection_nursery(autojump_clock):
+async def test_serve_listeners_connection_nursery(autojump_clock) -> None:
listener = MemoryListener()
- async def handler(stream):
+ async def handler(stream) -> None:
await trio.sleep(1)
class Done(Exception):
diff --git a/trio/_tests/test_highlevel_socket.py b/trio/_tests/test_highlevel_socket.py
index 514b6ad196..61e891e94b 100644
--- a/trio/_tests/test_highlevel_socket.py
+++ b/trio/_tests/test_highlevel_socket.py
@@ -211,7 +211,7 @@ async def test_SocketListener_socket_closed_underfoot() -> None:
async def test_SocketListener_accept_errors() -> None:
class FakeSocket(tsocket.SocketType):
- def __init__(self, events: Sequence[SocketType | BaseException]):
+ def __init__(self, events: Sequence[SocketType | BaseException]) -> None:
self._events = iter(events)
type = tsocket.SOCK_STREAM
diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py
index 89d921476a..9bec4ae2f5 100644
--- a/trio/_tests/test_highlevel_ssl_helpers.py
+++ b/trio/_tests/test_highlevel_ssl_helpers.py
@@ -17,7 +17,7 @@
from .test_ssl import SERVER_CTX, client_ctx # noqa: F401
-async def echo_handler(stream):
+async def echo_handler(stream) -> None:
async with stream:
try:
while True:
@@ -44,7 +44,9 @@ async def getnameinfo(self, *args): # pragma: no cover
# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
# using noqa because linters don't understand how pytest fixtures work.
-async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa: F811
+async def test_open_ssl_over_tcp_stream_and_everything_else(
+ client_ctx, # noqa: F811 # linters doesn't understand fixture
+) -> None:
async with trio.open_nursery() as nursery:
(listener,) = await nursery.start(
partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1")
@@ -97,7 +99,7 @@ async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa
nursery.cancel_scope.cancel()
-async def test_open_ssl_over_tcp_listeners():
+async def test_open_ssl_over_tcp_listeners() -> None:
(listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1")
async with listener:
assert isinstance(listener, trio.SSLListener)
diff --git a/trio/_tests/test_path.py b/trio/_tests/test_path.py
index bfef1aaf2c..1d17029bdf 100644
--- a/trio/_tests/test_path.py
+++ b/trio/_tests/test_path.py
@@ -20,14 +20,14 @@ def method_pair(path, method_name):
return getattr(path, method_name), getattr(async_path, method_name)
-async def test_open_is_async_context_manager(path):
+async def test_open_is_async_context_manager(path) -> None:
async with await path.open("w") as f:
assert isinstance(f, AsyncIOWrapper)
assert f.closed
-async def test_magic():
+async def test_magic() -> None:
path = trio.Path("test")
assert str(path) == "test"
@@ -42,7 +42,7 @@ async def test_magic():
@pytest.mark.parametrize("cls_a,cls_b", cls_pairs)
-async def test_cmp_magic(cls_a, cls_b):
+async def test_cmp_magic(cls_a, cls_b) -> None:
a, b = cls_a(""), cls_b("")
assert a == b
assert not a != b
@@ -69,7 +69,7 @@ async def test_cmp_magic(cls_a, cls_b):
@pytest.mark.parametrize("cls_a,cls_b", cls_pairs)
-async def test_div_magic(cls_a, cls_b):
+async def test_div_magic(cls_a, cls_b) -> None:
a, b = cls_a("a"), cls_b("b")
result = a / b
@@ -81,19 +81,19 @@ async def test_div_magic(cls_a, cls_b):
"cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)]
)
@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"])
-async def test_hash_magic(cls_a, cls_b, path):
+async def test_hash_magic(cls_a, cls_b, path) -> None:
a, b = cls_a(path), cls_b(path)
assert hash(a) == hash(b)
-async def test_forwarded_properties(path):
+async def test_forwarded_properties(path) -> None:
# use `name` as a representative of forwarded properties
assert "name" in dir(path)
assert path.name == "test"
-async def test_async_method_signature(path):
+async def test_async_method_signature(path) -> None:
# use `resolve` as a representative of wrapped methods
assert path.resolve.__name__ == "resolve"
@@ -103,7 +103,7 @@ async def test_async_method_signature(path):
@pytest.mark.parametrize("method_name", ["is_dir", "is_file"])
-async def test_compare_async_stat_methods(method_name):
+async def test_compare_async_stat_methods(method_name) -> None:
method, async_method = method_pair(".", method_name)
result = method()
@@ -112,13 +112,13 @@ async def test_compare_async_stat_methods(method_name):
assert result == async_result
-async def test_invalid_name_not_wrapped(path):
+async def test_invalid_name_not_wrapped(path) -> None:
with pytest.raises(AttributeError):
getattr(path, "invalid_fake_attr")
@pytest.mark.parametrize("method_name", ["absolute", "resolve"])
-async def test_async_methods_rewrap(method_name):
+async def test_async_methods_rewrap(method_name) -> None:
method, async_method = method_pair(".", method_name)
result = method()
@@ -128,7 +128,7 @@ async def test_async_methods_rewrap(method_name):
assert str(result) == str(async_result)
-async def test_forward_methods_rewrap(path, tmpdir):
+async def test_forward_methods_rewrap(path, tmpdir) -> None:
with_name = path.with_name("foo")
with_suffix = path.with_suffix(".py")
@@ -138,17 +138,17 @@ async def test_forward_methods_rewrap(path, tmpdir):
assert with_suffix == tmpdir.join("test.py")
-async def test_forward_properties_rewrap(path):
+async def test_forward_properties_rewrap(path) -> None:
assert isinstance(path.parent, trio.Path)
-async def test_forward_methods_without_rewrap(path, tmpdir):
+async def test_forward_methods_without_rewrap(path, tmpdir) -> None:
path = await path.parent.resolve()
assert path.as_uri().startswith("file:///")
-async def test_repr():
+async def test_repr() -> None:
path = trio.Path(".")
assert repr(path) == "trio.Path('.')"
@@ -164,30 +164,30 @@ class MockWrapper:
_wraps = MockWrapped
-async def test_type_forwards_unsupported():
+async def test_type_forwards_unsupported() -> None:
with pytest.raises(TypeError):
Type.generate_forwards(MockWrapper, {})
-async def test_type_wraps_unsupported():
+async def test_type_wraps_unsupported() -> None:
with pytest.raises(TypeError):
Type.generate_wraps(MockWrapper, {})
-async def test_type_forwards_private():
+async def test_type_forwards_private() -> None:
Type.generate_forwards(MockWrapper, {"unsupported": None})
assert not hasattr(MockWrapper, "_private")
-async def test_type_wraps_private():
+async def test_type_wraps_private() -> None:
Type.generate_wraps(MockWrapper, {"unsupported": None})
assert not hasattr(MockWrapper, "_private")
@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath])
-async def test_path_wraps_path(path, meth):
+async def test_path_wraps_path(path, meth) -> None:
wrapped = await path.absolute()
result = meth(path, wrapped)
if result is None:
@@ -196,17 +196,17 @@ async def test_path_wraps_path(path, meth):
assert wrapped == result
-async def test_path_nonpath():
+async def test_path_nonpath() -> None:
with pytest.raises(TypeError):
trio.Path(1)
-async def test_open_file_can_open_path(path):
+async def test_open_file_can_open_path(path) -> None:
async with await trio.open_file(path, "w") as f:
assert f.name == os.fspath(path)
-async def test_globmethods(path):
+async def test_globmethods(path) -> None:
# Populate a directory tree
await path.mkdir()
await (path / "foo").mkdir()
@@ -235,7 +235,7 @@ async def test_globmethods(path):
assert entries == {"_bar.txt", "bar.txt"}
-async def test_iterdir(path):
+async def test_iterdir(path) -> None:
# Populate a directory
await path.mkdir()
await (path / "foo").mkdir()
@@ -249,7 +249,7 @@ async def test_iterdir(path):
assert entries == {"bar.txt", "foo"}
-async def test_classmethods():
+async def test_classmethods() -> None:
assert isinstance(await trio.Path.home(), trio.Path)
# pathlib.Path has only two classmethods
diff --git a/trio/_tests/test_scheduler_determinism.py b/trio/_tests/test_scheduler_determinism.py
index e2d3167e45..1c438f136c 100644
--- a/trio/_tests/test_scheduler_determinism.py
+++ b/trio/_tests/test_scheduler_determinism.py
@@ -5,7 +5,7 @@ async def scheduler_trace():
"""Returns a scheduler-dependent value we can use to check determinism."""
trace = []
- async def tracer(name):
+ async def tracer(name) -> None:
for i in range(50):
trace.append((name, i))
await trio.sleep(0)
@@ -17,7 +17,7 @@ async def tracer(name):
return tuple(trace)
-def test_the_trio_scheduler_is_not_deterministic():
+def test_the_trio_scheduler_is_not_deterministic() -> None:
# At least, not yet. See https://github.com/python-trio/trio/issues/32
traces = []
for _ in range(10):
@@ -25,7 +25,7 @@ def test_the_trio_scheduler_is_not_deterministic():
assert len(set(traces)) == len(traces)
-def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch):
+def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch) -> None:
monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True)
traces = []
for _ in range(10):
diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py
index 40ffefb8cd..6b1753c82b 100644
--- a/trio/_tests/test_socket.py
+++ b/trio/_tests/test_socket.py
@@ -39,7 +39,7 @@
class MonkeypatchedGAI:
- def __init__(self, orig_getaddrinfo: Callable[..., getaddrinfoResponse]):
+ def __init__(self, orig_getaddrinfo: Callable[..., getaddrinfoResponse]) -> None:
self._orig_getaddrinfo = orig_getaddrinfo
self._responses: dict[tuple[Any, ...], getaddrinfoResponse | str] = {}
self.record: list[tuple[Any, ...]] = []
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index d92e35b0d9..06ac2e694c 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -167,7 +167,7 @@ async def ssl_echo_server(client_ctx, **kwargs):
# Doesn't inherit from Stream because I left out the methods that we don't
# actually need.
class PyOpenSSLEchoStream:
- def __init__(self, sleeper=None):
+ def __init__(self, sleeper=None) -> None:
ctx = SSL.Context(SSL.SSLv23_METHOD)
# TLS 1.3 removes renegotiation support. Which is great for them, but
# we still have to support versions before that, and that means we
@@ -217,31 +217,31 @@ def __init__(self, sleeper=None):
if sleeper is None:
- async def no_op_sleeper(_):
+ async def no_op_sleeper(_) -> None:
return
self.sleeper = no_op_sleeper
else:
self.sleeper = sleeper
- async def aclose(self):
+ async def aclose(self) -> None:
self._conn.bio_shutdown()
def renegotiate_pending(self):
return self._conn.renegotiate_pending()
- def renegotiate(self):
+ def renegotiate(self) -> None:
# Returns false if a renegotiation is already in progress, meaning
# nothing happens.
assert self._conn.renegotiate()
- async def wait_send_all_might_not_block(self):
+ async def wait_send_all_might_not_block(self) -> None:
with self._send_all_conflict_detector:
await _core.checkpoint()
await _core.checkpoint()
await self.sleeper("wait_send_all_might_not_block")
- async def send_all(self, data):
+ async def send_all(self, data) -> None:
print(" --> transport_stream.send_all")
with self._send_all_conflict_detector:
await _core.checkpoint()
@@ -320,7 +320,7 @@ async def receive_some(self, nbytes=None):
print(" <-- transport_stream.receive_some finished")
-async def test_PyOpenSSLEchoStream_gives_resource_busy_errors():
+async def test_PyOpenSSLEchoStream_gives_resource_busy_errors() -> None:
# Make sure that PyOpenSSLEchoStream complains if two tasks call send_all
# at the same time, or ditto for receive_some. The tricky cases where SSLStream
# might accidentally do this are during renegotiation, which we test using
@@ -395,7 +395,7 @@ def ssl_lockstep_stream_pair(client_ctx, **kwargs):
# Simple smoke test for handshake/send/receive/shutdown talking to a
# synchronous server, plus make sure that we do the bare minimum of
# certificate checking (even though this is really Python's responsibility)
-async def test_ssl_client_basics(client_ctx):
+async def test_ssl_client_basics(client_ctx) -> None:
# Everything OK
async with ssl_echo_server(client_ctx) as s:
assert not s.server_side
@@ -421,7 +421,7 @@ async def test_ssl_client_basics(client_ctx):
assert isinstance(excinfo.value.__cause__, ssl.CertificateError)
-async def test_ssl_server_basics(client_ctx):
+async def test_ssl_server_basics(client_ctx) -> None:
a, b = stdlib_socket.socketpair()
with a, b:
server_sock = tsocket.from_stdlib_socket(b)
@@ -430,7 +430,7 @@ async def test_ssl_server_basics(client_ctx):
)
assert server_transport.server_side
- def client():
+ def client() -> None:
with client_ctx.wrap_socket(
a, server_hostname="trio-test-1.example.org"
) as client_sock:
@@ -451,7 +451,7 @@ def client():
t.join()
-async def test_attributes(client_ctx):
+async def test_attributes(client_ctx) -> None:
async with ssl_echo_server_raw(expect_fail=True) as sock:
good_ctx = client_ctx
bad_ctx = ssl.create_default_context()
@@ -520,7 +520,7 @@ async def test_attributes(client_ctx):
# I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it...
-async def test_full_duplex_basics(client_ctx):
+async def test_full_duplex_basics(client_ctx) -> None:
CHUNKS = 30
CHUNK_SIZE = 32768
EXPECTED = CHUNKS * CHUNK_SIZE
@@ -528,7 +528,7 @@ async def test_full_duplex_basics(client_ctx):
sent = bytearray()
received = bytearray()
- async def sender(s):
+ async def sender(s) -> None:
nonlocal sent
for i in range(CHUNKS):
print(i)
@@ -536,7 +536,7 @@ async def sender(s):
sent += chunk
await s.send_all(chunk)
- async def receiver(s):
+ async def receiver(s) -> None:
nonlocal received
while len(received) < EXPECTED:
chunk = await s.receive_some(CHUNK_SIZE // 2)
@@ -557,7 +557,7 @@ async def receiver(s):
assert sent == received
-async def test_renegotiation_simple(client_ctx):
+async def test_renegotiation_simple(client_ctx) -> None:
with virtual_ssl_echo_server(client_ctx) as s:
await s.do_handshake()
@@ -576,7 +576,7 @@ async def test_renegotiation_simple(client_ctx):
@slow
-async def test_renegotiation_randomized(mock_clock, client_ctx):
+async def test_renegotiation_randomized(mock_clock, client_ctx) -> None:
# The only blocking things in this function are our random sleeps, so 0 is
# a good threshold.
mock_clock.autojump_threshold = 0
@@ -585,10 +585,10 @@ async def test_renegotiation_randomized(mock_clock, client_ctx):
r = random.Random(0)
- async def sleeper(_):
+ async def sleeper(_) -> None:
await trio.sleep(r.uniform(0, 10))
- async def clear():
+ async def clear() -> None:
while s.transport_stream.renegotiate_pending():
with assert_checkpoints():
await send(b"-")
@@ -596,13 +596,13 @@ async def clear():
await expect(b"-")
print("-- clear --")
- async def send(byte):
+ async def send(byte) -> None:
await s.transport_stream.sleeper("outer send")
print("calling SSLStream.send_all", byte)
with assert_checkpoints():
await s.send_all(byte)
- async def expect(expected):
+ async def expect(expected) -> None:
await s.transport_stream.sleeper("expect")
print("calling SSLStream.receive_some, expecting", expected)
assert len(expected) == 1
@@ -648,13 +648,13 @@ async def expect(expected):
# and wait_send_all_might_not_block comes in.
# Our receive_some() call will get stuck when it hits send_all
- async def sleeper_with_slow_send_all(method):
+ async def sleeper_with_slow_send_all(method) -> None:
if method == "send_all":
await trio.sleep(100000)
# And our wait_send_all_might_not_block call will give it time to get
# stuck, and then start
- async def sleep_then_wait_writable():
+ async def sleep_then_wait_writable() -> None:
await trio.sleep(1000)
await s.wait_send_all_might_not_block()
@@ -672,7 +672,7 @@ async def sleep_then_wait_writable():
# 2) Same, but now wait_send_all_might_not_block is stuck when
# receive_some tries to send.
- async def sleeper_with_slow_wait_writable_and_expect(method):
+ async def sleeper_with_slow_wait_writable_and_expect(method) -> None:
if method == "wait_send_all_might_not_block":
await trio.sleep(100000)
elif method == "expect":
@@ -692,16 +692,16 @@ async def sleeper_with_slow_wait_writable_and_expect(method):
await s.aclose()
-async def test_resource_busy_errors(client_ctx):
- async def do_send_all():
+async def test_resource_busy_errors(client_ctx) -> None:
+ async def do_send_all() -> None:
with assert_checkpoints():
await s.send_all(b"x")
- async def do_receive_some():
+ async def do_receive_some() -> None:
with assert_checkpoints():
await s.receive_some(1)
- async def do_wait_send_all_might_not_block():
+ async def do_wait_send_all_might_not_block() -> None:
with assert_checkpoints():
await s.wait_send_all_might_not_block()
@@ -734,11 +734,11 @@ async def do_wait_send_all_might_not_block():
assert "another task" in str(excinfo.value)
-async def test_wait_writable_calls_underlying_wait_writable():
+async def test_wait_writable_calls_underlying_wait_writable() -> None:
record = []
class NotAStream:
- async def wait_send_all_might_not_block(self):
+ async def wait_send_all_might_not_block(self) -> None:
record.append("ok")
ctx = ssl.create_default_context()
@@ -751,7 +751,7 @@ async def wait_send_all_might_not_block(self):
os.name == "nt" and sys.version_info >= (3, 10),
reason="frequently fails on Windows + Python 3.10",
)
-async def test_checkpoints(client_ctx):
+async def test_checkpoints(client_ctx) -> None:
async with ssl_echo_server(client_ctx) as s:
with assert_checkpoints():
await s.do_handshake()
@@ -780,7 +780,7 @@ async def test_checkpoints(client_ctx):
await s.aclose()
-async def test_send_all_empty_string(client_ctx):
+async def test_send_all_empty_string(client_ctx) -> None:
async with ssl_echo_server(client_ctx) as s:
await s.do_handshake()
@@ -797,7 +797,7 @@ async def test_send_all_empty_string(client_ctx):
@pytest.mark.parametrize("https_compatible", [False, True])
-async def test_SSLStream_generic(client_ctx, https_compatible):
+async def test_SSLStream_generic(client_ctx, https_compatible) -> None:
async def stream_maker():
return ssl_memory_stream_pair(
client_ctx,
@@ -821,14 +821,14 @@ async def clogged_stream_maker():
await check_two_way_stream(stream_maker, clogged_stream_maker)
-async def test_unwrap(client_ctx):
+async def test_unwrap(client_ctx) -> None:
client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
client_transport = client_ssl.transport_stream
server_transport = server_ssl.transport_stream
seq = Sequencer()
- async def client():
+ async def client() -> None:
await client_ssl.do_handshake()
await client_ssl.send_all(b"x")
assert await client_ssl.receive_some(1) == b"y"
@@ -855,7 +855,7 @@ async def client():
client_transport.send_stream.send_all_hook = send_all_hook
await client_transport.send_stream.send_all_hook()
- async def server():
+ async def server() -> None:
await server_ssl.do_handshake()
assert await server_ssl.receive_some(1) == b"x"
await server_ssl.send_all(b"y")
@@ -875,7 +875,7 @@ async def server():
nursery.start_soon(server)
-async def test_closing_nice_case(client_ctx):
+async def test_closing_nice_case(client_ctx) -> None:
# the nice case: graceful closes all around
client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
@@ -883,11 +883,11 @@ async def test_closing_nice_case(client_ctx):
# Both the handshake and the close require back-and-forth discussion, so
# we need to run them concurrently
- async def client_closer():
+ async def client_closer() -> None:
with assert_checkpoints():
await client_ssl.aclose()
- async def server_closer():
+ async def server_closer() -> None:
assert await server_ssl.receive_some(10) == b""
assert await server_ssl.receive_some(10) == b""
with assert_checkpoints():
@@ -926,7 +926,7 @@ async def server_closer():
# the other side
client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
- async def expect_eof_server():
+ async def expect_eof_server() -> None:
with assert_checkpoints():
assert await server_ssl.receive_some(10) == b""
with assert_checkpoints():
@@ -937,7 +937,7 @@ async def expect_eof_server():
nursery.start_soon(expect_eof_server)
-async def test_send_all_fails_in_the_middle(client_ctx):
+async def test_send_all_fails_in_the_middle(client_ctx) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
async with _core.open_nursery() as nursery:
@@ -957,7 +957,7 @@ async def bad_hook():
closed = 0
- def close_hook():
+ def close_hook() -> None:
nonlocal closed
closed += 1
@@ -968,7 +968,7 @@ def close_hook():
assert closed == 2
-async def test_ssl_over_ssl(client_ctx):
+async def test_ssl_over_ssl(client_ctx) -> None:
client_0, server_0 = memory_stream_pair()
client_1 = SSLStream(
@@ -981,11 +981,11 @@ async def test_ssl_over_ssl(client_ctx):
)
server_2 = SSLStream(server_1, SERVER_CTX, server_side=True)
- async def client():
+ async def client() -> None:
await client_2.send_all(b"hi")
assert await client_2.receive_some(10) == b"bye"
- async def server():
+ async def server() -> None:
assert await server_2.receive_some(10) == b"hi"
await server_2.send_all(b"bye")
@@ -994,7 +994,7 @@ async def server():
nursery.start_soon(server)
-async def test_ssl_bad_shutdown(client_ctx):
+async def test_ssl_bad_shutdown(client_ctx) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
async with _core.open_nursery() as nursery:
@@ -1011,7 +1011,7 @@ async def test_ssl_bad_shutdown(client_ctx):
await server.aclose()
-async def test_ssl_bad_shutdown_but_its_ok(client_ctx):
+async def test_ssl_bad_shutdown_but_its_ok(client_ctx) -> None:
client, server = ssl_memory_stream_pair(
client_ctx,
server_kwargs={"https_compatible": True},
@@ -1031,7 +1031,7 @@ async def test_ssl_bad_shutdown_but_its_ok(client_ctx):
await server.aclose()
-async def test_ssl_handshake_failure_during_aclose():
+async def test_ssl_handshake_failure_during_aclose() -> None:
# Weird scenario: aclose() triggers an automatic handshake, and this
# fails. This also exercises a bit of code in aclose() that was otherwise
# uncovered, for re-raising exceptions after calling aclose_forcefully on
@@ -1050,7 +1050,7 @@ async def test_ssl_handshake_failure_during_aclose():
await s.aclose()
-async def test_ssl_only_closes_stream_once(client_ctx):
+async def test_ssl_only_closes_stream_once(client_ctx) -> None:
# We used to have a bug where if transport_stream.aclose() raised an
# error, we would call it again. This checks that that's fixed.
client, server = ssl_memory_stream_pair(client_ctx)
@@ -1075,7 +1075,7 @@ def close_hook():
assert transport_close_count == 1
-async def test_ssl_https_compatibility_disagreement(client_ctx):
+async def test_ssl_https_compatibility_disagreement(client_ctx) -> None:
client, server = ssl_memory_stream_pair(
client_ctx,
server_kwargs={"https_compatible": False},
@@ -1088,7 +1088,7 @@ async def test_ssl_https_compatibility_disagreement(client_ctx):
# client is in HTTPS-mode, server is not
# so client doing graceful_shutdown causes an error on server
- async def receive_and_expect_error():
+ async def receive_and_expect_error() -> None:
with pytest.raises(BrokenResourceError) as excinfo:
await server.receive_some(10)
@@ -1099,14 +1099,14 @@ async def receive_and_expect_error():
nursery.start_soon(receive_and_expect_error)
-async def test_https_mode_eof_before_handshake(client_ctx):
+async def test_https_mode_eof_before_handshake(client_ctx) -> None:
client, server = ssl_memory_stream_pair(
client_ctx,
server_kwargs={"https_compatible": True},
client_kwargs={"https_compatible": True},
)
- async def server_expect_clean_eof():
+ async def server_expect_clean_eof() -> None:
assert await server.receive_some(10) == b""
async with _core.open_nursery() as nursery:
@@ -1114,7 +1114,7 @@ async def server_expect_clean_eof():
nursery.start_soon(server_expect_clean_eof)
-async def test_send_error_during_handshake(client_ctx):
+async def test_send_error_during_handshake(client_ctx) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
async def bad_hook():
@@ -1131,7 +1131,7 @@ async def bad_hook():
await client.do_handshake()
-async def test_receive_error_during_handshake(client_ctx):
+async def test_receive_error_during_handshake(client_ctx) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
async def bad_hook():
@@ -1139,7 +1139,7 @@ async def bad_hook():
client.transport_stream.receive_stream.receive_some_hook = bad_hook
- async def client_side(cancel_scope):
+ async def client_side(cancel_scope) -> None:
with pytest.raises(KeyError):
with assert_checkpoints():
await client.do_handshake()
@@ -1154,7 +1154,7 @@ async def client_side(cancel_scope):
await client.do_handshake()
-async def test_selected_alpn_protocol_before_handshake(client_ctx):
+async def test_selected_alpn_protocol_before_handshake(client_ctx) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
with pytest.raises(NeedHandshakeError):
@@ -1164,7 +1164,7 @@ async def test_selected_alpn_protocol_before_handshake(client_ctx):
server.selected_alpn_protocol()
-async def test_selected_alpn_protocol_when_not_set(client_ctx):
+async def test_selected_alpn_protocol_when_not_set(client_ctx) -> None:
# ALPN protocol still returns None when it's not set,
# instead of raising an exception
client, server = ssl_memory_stream_pair(client_ctx)
@@ -1179,7 +1179,7 @@ async def test_selected_alpn_protocol_when_not_set(client_ctx):
assert client.selected_alpn_protocol() == server.selected_alpn_protocol()
-async def test_selected_npn_protocol_before_handshake(client_ctx):
+async def test_selected_npn_protocol_before_handshake(client_ctx) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
with pytest.raises(NeedHandshakeError):
@@ -1193,7 +1193,7 @@ async def test_selected_npn_protocol_before_handshake(client_ctx):
r"ignore: ssl module. NPN is deprecated, use ALPN instead:UserWarning",
r"ignore:ssl NPN is deprecated, use ALPN instead:DeprecationWarning",
)
-async def test_selected_npn_protocol_when_not_set(client_ctx):
+async def test_selected_npn_protocol_when_not_set(client_ctx) -> None:
# NPN protocol still returns None when it's not set,
# instead of raising an exception
client, server = ssl_memory_stream_pair(client_ctx)
@@ -1208,7 +1208,7 @@ async def test_selected_npn_protocol_when_not_set(client_ctx):
assert client.selected_npn_protocol() == server.selected_npn_protocol()
-async def test_get_channel_binding_before_handshake(client_ctx):
+async def test_get_channel_binding_before_handshake(client_ctx) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
with pytest.raises(NeedHandshakeError):
@@ -1218,7 +1218,7 @@ async def test_get_channel_binding_before_handshake(client_ctx):
server.get_channel_binding()
-async def test_get_channel_binding_after_handshake(client_ctx):
+async def test_get_channel_binding_after_handshake(client_ctx) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
async with _core.open_nursery() as nursery:
@@ -1231,7 +1231,7 @@ async def test_get_channel_binding_after_handshake(client_ctx):
assert client.get_channel_binding() == server.get_channel_binding()
-async def test_getpeercert(client_ctx):
+async def test_getpeercert(client_ctx) -> None:
# Make sure we're not affected by https://bugs.python.org/issue29334
client, server = ssl_memory_stream_pair(client_ctx)
@@ -1244,7 +1244,7 @@ async def test_getpeercert(client_ctx):
assert ("DNS", "trio-test-1.example.org") in client.getpeercert()["subjectAltName"]
-async def test_SSLListener(client_ctx):
+async def test_SSLListener(client_ctx) -> None:
async def setup(**kwargs):
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index 17cf740012..2d805a1814 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -97,7 +97,7 @@ async def run_process_in_nursery(*args, **kwargs):
@background_process_param
-async def test_basic(background_process):
+async def test_basic(background_process) -> None:
async with background_process(EXIT_TRUE) as proc:
await proc.wait()
assert isinstance(proc, Process)
@@ -114,7 +114,7 @@ async def test_basic(background_process):
@background_process_param
-async def test_auto_update_returncode(background_process):
+async def test_auto_update_returncode(background_process) -> None:
async with background_process(SLEEP(9999)) as p:
assert p.returncode is None
assert "running" in repr(p)
@@ -127,7 +127,7 @@ async def test_auto_update_returncode(background_process):
@background_process_param
-async def test_multi_wait(background_process):
+async def test_multi_wait(background_process) -> None:
async with background_process(SLEEP(10)) as proc:
# Check that wait (including multi-wait) tolerates being cancelled
async with _core.open_nursery() as nursery:
@@ -147,7 +147,7 @@ async def test_multi_wait(background_process):
# Test for deprecated 'async with process:' semantics
-async def test_async_with_basics_deprecated(recwarn):
+async def test_async_with_basics_deprecated(recwarn) -> None:
async with await open_process(
CAT, stdin=subprocess.PIPE, stdout=subprocess.PIPE
) as proc:
@@ -160,7 +160,7 @@ async def test_async_with_basics_deprecated(recwarn):
# Test for deprecated 'async with process:' semantics
-async def test_kill_when_context_cancelled(recwarn):
+async def test_kill_when_context_cancelled(recwarn) -> None:
with move_on_after(100) as scope:
async with await open_process(SLEEP(10)) as proc:
assert proc.poll() is None
@@ -181,7 +181,7 @@ async def test_kill_when_context_cancelled(recwarn):
@background_process_param
-async def test_pipes(background_process):
+async def test_pipes(background_process) -> None:
async with background_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
@@ -190,11 +190,11 @@ async def test_pipes(background_process):
) as proc:
msg = b"the quick brown fox jumps over the lazy dog"
- async def feed_input():
+ async def feed_input() -> None:
await proc.stdin.send_all(msg)
await proc.stdin.aclose()
- async def check_output(stream, expected):
+ async def check_output(stream, expected) -> None:
seen = bytearray()
async for chunk in stream:
seen += chunk
@@ -212,7 +212,7 @@ async def check_output(stream, expected):
@background_process_param
-async def test_interactive(background_process):
+async def test_interactive(background_process) -> None:
# Test some back-and-forth with a subprocess. This one works like so:
# in: 32\n
# out: 0000...0000\n (32 zeroes)
@@ -241,10 +241,10 @@ async def test_interactive(background_process):
) as proc:
newline = b"\n" if posix else b"\r\n"
- async def expect(idx, request):
+ async def expect(idx, request) -> None:
async with _core.open_nursery() as nursery:
- async def drain_one(stream, count, digit):
+ async def drain_one(stream, count, digit) -> None:
while count > 0:
result = await stream.receive_some(count)
assert result == (f"{digit}".encode() * len(result))
@@ -279,7 +279,7 @@ async def drain_one(stream, count, digit):
assert proc.returncode == 0
-async def test_run():
+async def test_run() -> None:
data = bytes(random.randint(0, 255) for _ in range(2**18))
result = await run_process(
@@ -322,7 +322,7 @@ async def test_run():
await run_process(CAT, capture_stderr=True, stderr=None)
-async def test_run_check():
+async def test_run_check() -> None:
cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)")
with pytest.raises(subprocess.CalledProcessError) as excinfo:
await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True)
@@ -341,7 +341,7 @@ async def test_run_check():
@skip_if_fbsd_pipes_broken
-async def test_run_with_broken_pipe():
+async def test_run_with_broken_pipe() -> None:
result = await run_process(
[sys.executable, "-c", "import sys; sys.stdin.close()"], stdin=b"x" * 131072
)
@@ -350,7 +350,7 @@ async def test_run_with_broken_pipe():
@background_process_param
-async def test_stderr_stdout(background_process):
+async def test_stderr_stdout(background_process) -> None:
async with background_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
@@ -416,7 +416,7 @@ async def test_stderr_stdout(background_process):
os.close(r)
-async def test_errors():
+async def test_errors() -> None:
with pytest.raises(TypeError) as excinfo:
await open_process(["ls"], encoding="utf-8")
assert "unbuffered byte streams" in str(excinfo.value)
@@ -430,8 +430,8 @@ async def test_errors():
@background_process_param
-async def test_signals(background_process):
- async def test_one_signal(send_it, signum):
+async def test_signals(background_process) -> None:
+ async def test_one_signal(send_it, signum) -> None:
with move_on_after(1.0) as scope:
async with background_process(SLEEP(3600)) as proc:
send_it(proc)
@@ -457,7 +457,7 @@ async def test_one_signal(send_it, signum):
@pytest.mark.skipif(not posix, reason="POSIX specific")
@background_process_param
-async def test_wait_reapable_fails(background_process):
+async def test_wait_reapable_fails(background_process) -> None:
old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
try:
# With SIGCHLD disabled, the wait() syscall will wait for the
@@ -476,7 +476,7 @@ async def test_wait_reapable_fails(background_process):
@slow
-def test_waitid_eintr():
+def test_waitid_eintr() -> None:
# This only matters on PyPy (where we're coding EINTR handling
# ourselves) but the test works on all waitid platforms.
from .._subprocess_platform import wait_child_exiting
@@ -488,7 +488,7 @@ def test_waitid_eintr():
got_alarm = False
sleeper = subprocess.Popen(["sleep", "3600"])
- def on_alarm(sig, frame):
+ def on_alarm(sig, frame) -> None:
nonlocal got_alarm
got_alarm = True
sleeper.kill()
@@ -507,10 +507,10 @@ def on_alarm(sig, frame):
signal.signal(signal.SIGALRM, old_sigalrm)
-async def test_custom_deliver_cancel():
+async def test_custom_deliver_cancel() -> None:
custom_deliver_cancel_called = False
- async def custom_deliver_cancel(proc):
+ async def custom_deliver_cancel(proc) -> None:
nonlocal custom_deliver_cancel_called
custom_deliver_cancel_called = True
proc.terminate()
@@ -531,7 +531,7 @@ async def custom_deliver_cancel(proc):
assert custom_deliver_cancel_called
-async def test_warn_on_failed_cancel_terminate(monkeypatch):
+async def test_warn_on_failed_cancel_terminate(monkeypatch) -> None:
original_terminate = Process.terminate
def broken_terminate(self):
@@ -548,7 +548,7 @@ def broken_terminate(self):
@pytest.mark.skipif(not posix, reason="posix only")
-async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch):
+async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch) -> None:
monkeypatch.setattr(Process, "terminate", lambda *args: None)
with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"):
@@ -560,7 +560,7 @@ async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch):
# the background_process_param exercises a lot of run_process cases, but it uses
# check=False, so lets have a test that uses check=True as well
-async def test_run_process_background_fail():
+async def test_run_process_background_fail() -> None:
with pytest.raises(subprocess.CalledProcessError):
async with _core.open_nursery() as nursery:
proc = await nursery.start(run_process, EXIT_FALSE)
@@ -571,7 +571,7 @@ async def test_run_process_background_fail():
not SyncPath("/dev/fd").exists(),
reason="requires a way to iterate through open files",
)
-async def test_for_leaking_fds():
+async def test_for_leaking_fds() -> None:
starting_fds = set(SyncPath("/dev/fd").iterdir())
await run_process(EXIT_TRUE)
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
@@ -586,7 +586,7 @@ async def test_for_leaking_fds():
# regression test for #2209
-async def test_subprocess_pidfd_unnotified():
+async def test_subprocess_pidfd_unnotified() -> None:
noticed_exit = None
async def wait_and_tell(proc: Process) -> None:
diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py
index 7de42b86f9..4325e33a6c 100644
--- a/trio/_tests/test_sync.py
+++ b/trio/_tests/test_sync.py
@@ -8,7 +8,7 @@
from ..testing import assert_checkpoints, wait_all_tasks_blocked
-async def test_Event():
+async def test_Event() -> None:
e = Event()
assert not e.is_set()
assert e.statistics().tasks_waiting == 0
@@ -22,7 +22,7 @@ async def test_Event():
record = []
- async def child():
+ async def child() -> None:
record.append("sleeping")
await e.wait()
record.append("woken")
@@ -38,7 +38,7 @@ async def child():
assert record == ["sleeping", "sleeping", "woken", "woken"]
-async def test_CapacityLimiter():
+async def test_CapacityLimiter() -> None:
with pytest.raises(TypeError):
CapacityLimiter(1.0)
with pytest.raises(ValueError):
@@ -107,7 +107,7 @@ async def test_CapacityLimiter():
c.release_on_behalf_of("value 1")
-async def test_CapacityLimiter_inf():
+async def test_CapacityLimiter_inf() -> None:
from math import inf
c = CapacityLimiter(inf)
@@ -123,7 +123,7 @@ async def test_CapacityLimiter_inf():
assert c.available_tokens == inf
-async def test_CapacityLimiter_change_total_tokens():
+async def test_CapacityLimiter_change_total_tokens() -> None:
c = CapacityLimiter(2)
with pytest.raises(TypeError):
@@ -160,7 +160,7 @@ async def test_CapacityLimiter_change_total_tokens():
# regression test for issue #548
-async def test_CapacityLimiter_memleak_548():
+async def test_CapacityLimiter_memleak_548() -> None:
limiter = CapacityLimiter(total_tokens=1)
await limiter.acquire()
@@ -174,7 +174,7 @@ async def test_CapacityLimiter_memleak_548():
assert len(limiter._pending_borrowers) == 0
-async def test_Semaphore():
+async def test_Semaphore() -> None:
with pytest.raises(TypeError):
Semaphore(1.0)
with pytest.raises(ValueError):
@@ -204,7 +204,7 @@ async def test_Semaphore():
record = []
- async def do_acquire(s):
+ async def do_acquire(s) -> None:
record.append("started")
await s.acquire()
record.append("finished")
@@ -222,7 +222,7 @@ async def do_acquire(s):
assert record == ["started", "finished"]
-async def test_Semaphore_bounded():
+async def test_Semaphore_bounded() -> None:
with pytest.raises(TypeError):
Semaphore(1, max_value=1.0)
with pytest.raises(ValueError):
@@ -240,7 +240,7 @@ async def test_Semaphore_bounded():
@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
-async def test_Lock_and_StrictFIFOLock(lockcls):
+async def test_Lock_and_StrictFIFOLock(lockcls) -> None:
l = lockcls() # noqa
assert not l.locked()
@@ -277,7 +277,7 @@ async def test_Lock_and_StrictFIFOLock(lockcls):
holder_task = None
- async def holder():
+ async def holder() -> None:
nonlocal holder_task
holder_task = _core.current_task()
async with l:
@@ -315,7 +315,7 @@ async def holder():
assert statistics.tasks_waiting == 0
-async def test_Condition():
+async def test_Condition() -> None:
with pytest.raises(TypeError):
Condition(Semaphore(1))
with pytest.raises(TypeError):
@@ -349,7 +349,7 @@ async def test_Condition():
finished_waiters = set()
- async def waiter(i):
+ async def waiter(i) -> None:
async with c:
await c.wait()
finished_waiters.add(i)
@@ -407,56 +407,56 @@ async def waiter(i):
class ChannelLock1(AsyncContextManagerMixin):
- def __init__(self, capacity):
+ def __init__(self, capacity) -> None:
self.s, self.r = open_memory_channel(capacity)
for _ in range(capacity - 1):
self.s.send_nowait(None)
- def acquire_nowait(self):
+ def acquire_nowait(self) -> None:
self.s.send_nowait(None)
- async def acquire(self):
+ async def acquire(self) -> None:
await self.s.send(None)
- def release(self):
+ def release(self) -> None:
self.r.receive_nowait()
class ChannelLock2(AsyncContextManagerMixin):
- def __init__(self):
+ def __init__(self) -> None:
self.s, self.r = open_memory_channel(10)
self.s.send_nowait(None)
- def acquire_nowait(self):
+ def acquire_nowait(self) -> None:
self.r.receive_nowait()
- async def acquire(self):
+ async def acquire(self) -> None:
await self.r.receive()
- def release(self):
+ def release(self) -> None:
self.s.send_nowait(None)
class ChannelLock3(AsyncContextManagerMixin):
- def __init__(self):
+ def __init__(self) -> None:
self.s, self.r = open_memory_channel(0)
# self.acquired is true when one task acquires the lock and
# only becomes false when it's released and no tasks are
# waiting to acquire.
self.acquired = False
- def acquire_nowait(self):
+ def acquire_nowait(self) -> None:
assert not self.acquired
self.acquired = True
- async def acquire(self):
+ async def acquire(self) -> None:
if self.acquired:
await self.s.send(None)
else:
self.acquired = True
await _core.checkpoint()
- def release(self):
+ def release(self) -> None:
try:
self.r.receive_nowait()
except _core.WouldBlock:
@@ -493,13 +493,13 @@ def release(self):
# Spawn a bunch of workers that take a lock and then yield; make sure that
# only one worker is ever in the critical section at a time.
@generic_lock_test
-async def test_generic_lock_exclusion(lock_factory):
+async def test_generic_lock_exclusion(lock_factory) -> None:
LOOPS = 10
WORKERS = 5
in_critical_section = False
acquires = 0
- async def worker(lock_like):
+ async def worker(lock_like) -> None:
nonlocal in_critical_section, acquires
for _ in range(LOOPS):
async with lock_like:
@@ -522,12 +522,12 @@ async def worker(lock_like):
# Several workers queue on the same lock; make sure they each get it, in
# order.
@generic_lock_test
-async def test_generic_lock_fifo_fairness(lock_factory):
+async def test_generic_lock_fifo_fairness(lock_factory) -> None:
initial_order = []
record = []
LOOPS = 5
- async def loopy(name, lock_like):
+ async def loopy(name, lock_like) -> None:
# Record the order each task was initially scheduled in
initial_order.append(name)
for _ in range(LOOPS):
@@ -546,12 +546,12 @@ async def loopy(name, lock_like):
@generic_lock_test
-async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory):
+async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory) -> None:
lock_like = lock_factory()
record = []
- async def lock_taker():
+ async def lock_taker() -> None:
record.append("started")
async with lock_like:
pass
diff --git a/trio/_tests/test_testing.py b/trio/_tests/test_testing.py
index 26d9298807..e7e0f95535 100644
--- a/trio/_tests/test_testing.py
+++ b/trio/_tests/test_testing.py
@@ -13,15 +13,15 @@
from ..testing._memory_streams import _UnboundedByteQueue
-async def test_wait_all_tasks_blocked():
+async def test_wait_all_tasks_blocked() -> None:
record = []
- async def busy_bee():
+ async def busy_bee() -> None:
for _ in range(10):
await _core.checkpoint()
record.append("busy bee exhausted")
- async def waiting_for_bee_to_leave():
+ async def waiting_for_bee_to_leave() -> None:
await wait_all_tasks_blocked()
record.append("quiet at last!")
@@ -33,7 +33,7 @@ async def waiting_for_bee_to_leave():
# check cancellation
record = []
- async def cancelled_while_waiting():
+ async def cancelled_while_waiting() -> None:
try:
await wait_all_tasks_blocked()
except _core.Cancelled:
@@ -45,10 +45,10 @@ async def cancelled_while_waiting():
assert record == ["ok"]
-async def test_wait_all_tasks_blocked_with_timeouts(mock_clock):
+async def test_wait_all_tasks_blocked_with_timeouts(mock_clock) -> None:
record = []
- async def timeout_task():
+ async def timeout_task() -> None:
record.append("tt start")
await sleep(5)
record.append("tt finished")
@@ -62,25 +62,25 @@ async def timeout_task():
assert record == ["tt start", "tt finished"]
-async def test_wait_all_tasks_blocked_with_cushion():
+async def test_wait_all_tasks_blocked_with_cushion() -> None:
record = []
- async def blink():
+ async def blink() -> None:
record.append("blink start")
await sleep(0.01)
await sleep(0.01)
await sleep(0.01)
record.append("blink end")
- async def wait_no_cushion():
+ async def wait_no_cushion() -> None:
await wait_all_tasks_blocked()
record.append("wait_no_cushion end")
- async def wait_small_cushion():
+ async def wait_small_cushion() -> None:
await wait_all_tasks_blocked(0.02)
record.append("wait_small_cushion end")
- async def wait_big_cushion():
+ async def wait_big_cushion() -> None:
await wait_all_tasks_blocked(0.03)
record.append("wait_big_cushion end")
@@ -104,7 +104,7 @@ async def wait_big_cushion():
################################################################
-async def test_assert_checkpoints(recwarn):
+async def test_assert_checkpoints(recwarn) -> None:
with assert_checkpoints():
await _core.checkpoint()
@@ -130,7 +130,7 @@ async def test_assert_checkpoints(recwarn):
await _core.cancel_shielded_checkpoint()
-async def test_assert_no_checkpoints(recwarn):
+async def test_assert_no_checkpoints(recwarn) -> None:
with assert_no_checkpoints():
1 + 1
@@ -160,14 +160,14 @@ async def test_assert_no_checkpoints(recwarn):
################################################################
-async def test_Sequencer():
+async def test_Sequencer() -> None:
record = []
- def t(val):
+ def t(val) -> None:
print(val)
record.append(val)
- async def f1(seq):
+ async def f1(seq) -> None:
async with seq(1):
t(("f1", 1))
async with seq(3):
@@ -175,7 +175,7 @@ async def f1(seq):
async with seq(4):
t(("f1", 4))
- async def f2(seq):
+ async def f2(seq) -> None:
async with seq(0):
t(("f2", 0))
async with seq(2):
@@ -198,12 +198,12 @@ async def f2(seq):
pass # pragma: no cover
-async def test_Sequencer_cancel():
+async def test_Sequencer_cancel() -> None:
# Killing a blocked task makes everything blow up
record = []
seq = Sequencer()
- async def child(i):
+ async def child(i) -> None:
with _core.CancelScope() as scope:
if i == 1:
scope.cancel()
@@ -245,7 +245,7 @@ async def test__assert_raises():
# This is a private implementation detail, but it's complex enough to be worth
# testing directly
-async def test__UnboundeByteQueue():
+async def test__UnboundeByteQueue() -> None:
ubq = _UnboundedByteQueue()
ubq.put(b"123")
@@ -273,11 +273,11 @@ async def test__UnboundeByteQueue():
with assert_checkpoints():
assert await ubq.get() == b"efghi"
- async def putter(data):
+ async def putter(data) -> None:
await wait_all_tasks_blocked()
ubq.put(data)
- async def getter(expect):
+ async def getter(expect) -> None:
with assert_checkpoints():
assert await ubq.get() == expect
@@ -308,7 +308,7 @@ async def getter(expect):
# close wakes up blocked getters
ubq2 = _UnboundedByteQueue()
- async def closer():
+ async def closer() -> None:
await wait_all_tasks_blocked()
ubq2.close()
@@ -317,10 +317,10 @@ async def closer():
nursery.start_soon(closer)
-async def test_MemorySendStream():
+async def test_MemorySendStream() -> None:
mss = MemorySendStream()
- async def do_send_all(data):
+ async def do_send_all(data) -> None:
with assert_checkpoints():
await mss.send_all(data)
@@ -346,7 +346,7 @@ async def do_send_all(data):
# and we don't know which one will get the error.
resource_busy_count = 0
- async def do_send_all_count_resourcebusy():
+ async def do_send_all_count_resourcebusy() -> None:
nonlocal resource_busy_count
try:
await do_send_all(b"xxx")
@@ -375,15 +375,15 @@ async def do_send_all_count_resourcebusy():
record = []
- async def send_all_hook():
+ async def send_all_hook() -> None:
# hook runs after send_all does its work (can pull data out)
assert mss2.get_data_nowait() == b"abc"
record.append("send_all_hook")
- async def wait_send_all_might_not_block_hook():
+ async def wait_send_all_might_not_block_hook() -> None:
record.append("wait_send_all_might_not_block_hook")
- def close_hook():
+ def close_hook() -> None:
record.append("close_hook")
mss2 = MemorySendStream(
@@ -407,7 +407,7 @@ def close_hook():
]
-async def test_MemoryReceiveStream():
+async def test_MemoryReceiveStream() -> None:
mrs = MemoryReceiveStream()
async def do_receive_some(max_bytes):
@@ -438,12 +438,12 @@ async def do_receive_some(max_bytes):
with pytest.raises(_core.ClosedResourceError):
mrs.put_data(b"---")
- async def receive_some_hook():
+ async def receive_some_hook() -> None:
mrs2.put_data(b"xxx")
record = []
- def close_hook():
+ def close_hook() -> None:
record.append("closed")
mrs2 = MemoryReceiveStream(receive_some_hook, close_hook)
@@ -468,7 +468,7 @@ def close_hook():
await mrs2.receive_some(10)
-async def test_MemoryRecvStream_closing():
+async def test_MemoryRecvStream_closing() -> None:
mrs = MemoryReceiveStream()
# close with no pending data
mrs.close()
@@ -488,7 +488,7 @@ async def test_MemoryRecvStream_closing():
await mrs2.receive_some(10)
-async def test_memory_stream_pump():
+async def test_memory_stream_pump() -> None:
mss = MemorySendStream()
mrs = MemoryReceiveStream()
@@ -512,7 +512,7 @@ async def test_memory_stream_pump():
assert await mrs.receive_some(10) == b""
-async def test_memory_stream_one_way_pair():
+async def test_memory_stream_one_way_pair() -> None:
s, r = memory_stream_one_way_pair()
assert s.send_all_hook is not None
assert s.wait_send_all_might_not_block_hook is None
@@ -521,7 +521,7 @@ async def test_memory_stream_one_way_pair():
await s.send_all(b"123")
assert await r.receive_some(10) == b"123"
- async def receiver(expected):
+ async def receiver(expected) -> None:
assert await r.receive_some(10) == expected
# This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
@@ -549,11 +549,11 @@ async def receiver(expected):
s.send_all_hook = None
await s.send_all(b"456")
- async def cancel_after_idle(nursery):
+ async def cancel_after_idle(nursery) -> None:
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
- async def check_for_cancel():
+ async def check_for_cancel() -> None:
with pytest.raises(_core.Cancelled):
# This should block forever... or until cancelled. Even though we
# sent some data on the send stream.
@@ -568,7 +568,7 @@ async def check_for_cancel():
assert await r.receive_some(10) == b"456789"
-async def test_memory_stream_pair():
+async def test_memory_stream_pair() -> None:
a, b = memory_stream_pair()
await a.send_all(b"123")
await b.send_all(b"abc")
@@ -578,11 +578,11 @@ async def test_memory_stream_pair():
await a.send_eof()
assert await b.receive_some(10) == b""
- async def sender():
+ async def sender() -> None:
await wait_all_tasks_blocked()
await b.send_all(b"xyz")
- async def receiver():
+ async def receiver() -> None:
assert await a.receive_some(10) == b"xyz"
async with _core.open_nursery() as nursery:
@@ -590,7 +590,7 @@ async def receiver():
nursery.start_soon(sender)
-async def test_memory_streams_with_generic_tests():
+async def test_memory_streams_with_generic_tests() -> None:
async def one_way_stream_maker():
return memory_stream_one_way_pair()
@@ -602,7 +602,7 @@ async def half_closeable_stream_maker():
await check_half_closeable_stream(half_closeable_stream_maker, None)
-async def test_lockstep_streams_with_generic_tests():
+async def test_lockstep_streams_with_generic_tests() -> None:
async def one_way_stream_maker():
return lockstep_stream_one_way_pair()
@@ -614,8 +614,8 @@ async def two_way_stream_maker():
await check_two_way_stream(two_way_stream_maker, two_way_stream_maker)
-async def test_open_stream_to_socket_listener():
- async def check(listener):
+async def test_open_stream_to_socket_listener() -> None:
+ async def check(listener) -> None:
async with listener:
client_stream = await open_stream_to_socket_listener(listener)
async with client_stream:
diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py
index 9e448a4d38..246b50533f 100644
--- a/trio/_tests/test_threads.py
+++ b/trio/_tests/test_threads.py
@@ -27,13 +27,13 @@
from ..testing import wait_all_tasks_blocked
-async def test_do_in_trio_thread():
+async def test_do_in_trio_thread() -> None:
trio_thread = threading.current_thread()
- async def check_case(do_in_trio_thread, fn, expected, trio_token=None):
+ async def check_case(do_in_trio_thread, fn, expected, trio_token=None) -> None:
record = []
- def threadfn():
+ def threadfn() -> None:
try:
record.append(("start", threading.current_thread()))
x = do_in_trio_thread(fn, record, trio_token=trio_token)
@@ -51,7 +51,7 @@ def threadfn():
token = _core.current_trio_token()
- def f(record):
+ def f(record) -> int:
assert not _core.currently_ki_protected()
record.append(("f", threading.current_thread()))
return 2
@@ -65,7 +65,7 @@ def f(record):
await check_case(from_thread_run_sync, f, ("error", ValueError), trio_token=token)
- async def f(record):
+ async def f(record) -> int:
assert not _core.currently_ki_protected()
await _core.checkpoint()
record.append(("f", threading.current_thread()))
@@ -82,26 +82,26 @@ async def f(record):
await check_case(from_thread_run, f, ("error", KeyError), trio_token=token)
-async def test_do_in_trio_thread_from_trio_thread():
+async def test_do_in_trio_thread_from_trio_thread() -> None:
with pytest.raises(RuntimeError):
from_thread_run_sync(lambda: None) # pragma: no branch
- async def foo(): # pragma: no cover
+ async def foo() -> None: # pragma: no cover
pass
with pytest.raises(RuntimeError):
from_thread_run(foo)
-def test_run_in_trio_thread_ki():
+def test_run_in_trio_thread_ki() -> None:
# if we get a control-C during a run_in_trio_thread, then it propagates
# back to the caller (slick!)
record = set()
- async def check_run_in_trio_thread():
+ async def check_run_in_trio_thread() -> None:
token = _core.current_trio_token()
- def trio_thread_fn():
+ def trio_thread_fn() -> None:
print("in Trio thread")
assert not _core.currently_ki_protected()
print("ki_self")
@@ -112,10 +112,10 @@ def trio_thread_fn():
print("finally", sys.exc_info())
- async def trio_thread_afn():
+ async def trio_thread_afn() -> None:
trio_thread_fn()
- def external_thread_fn():
+ def external_thread_fn() -> None:
try:
print("running")
from_thread_run_sync(trio_thread_fn, trio_token=token)
@@ -141,16 +141,16 @@ def external_thread_fn():
assert record == {"ok1", "ok2"}
-def test_await_in_trio_thread_while_main_exits():
+def test_await_in_trio_thread_while_main_exits() -> None:
record = []
ev = Event()
- async def trio_fn():
+ async def trio_fn() -> None:
record.append("sleeping")
ev.set()
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
- def thread_fn(token):
+ def thread_fn(token) -> None:
try:
from_thread_run(trio_fn, trio_token=token)
except _core.Cancelled:
@@ -169,7 +169,7 @@ async def main():
assert record == ["sleeping", "cancelled"]
-async def test_named_thread():
+async def test_named_thread() -> None:
ending = " from trio._tests.test_threads.test_named_thread"
def inner(name: str = "inner" + ending) -> threading.Thread:
@@ -236,7 +236,7 @@ def _get_thread_name(ident: Optional[int] = None) -> Optional[str]:
# this depends on pthread being available, which is the case on 99.9% of linux machines
# and most mac machines. So unless the platform is linux it will just skip
# in case it fails to fetch the os thread name.
-async def test_named_thread_os():
+async def test_named_thread_os() -> None:
def inner(name: str) -> threading.Thread:
os_thread_name = _get_thread_name()
if os_thread_name is None and sys.platform != "linux":
@@ -271,7 +271,7 @@ async def test_thread_name(name: str, expected: Optional[str] = None) -> None:
await test_thread_name("💙", expected="?")
-async def test_has_pthread_setname_np():
+async def test_has_pthread_setname_np() -> None:
from trio._core._thread_cache import get_os_thread_name_func
k = get_os_thread_name_func()
@@ -280,7 +280,7 @@ async def test_has_pthread_setname_np():
pytest.skip(f"no pthread_setname_np on {sys.platform}")
-async def test_run_in_worker_thread():
+async def test_run_in_worker_thread() -> None:
trio_thread = threading.current_thread()
def f(x):
@@ -299,10 +299,10 @@ def g():
assert excinfo.value.args[0] != trio_thread
-async def test_run_in_worker_thread_cancellation():
+async def test_run_in_worker_thread_cancellation() -> None:
register = [None]
- def f(q):
+ def f(q) -> None:
# Make the thread block for a controlled amount of time
register[0] = "blocking"
q.get()
@@ -359,18 +359,18 @@ async def child(q, cancellable):
# Make sure that if trio.run exits, and then the thread finishes, then that's
# handled gracefully. (Requires that the thread result machinery be prepared
# for call_soon to raise RunFinishedError.)
-def test_run_in_worker_thread_abandoned(capfd, monkeypatch):
+def test_run_in_worker_thread_abandoned(capfd, monkeypatch) -> None:
monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)
q1 = stdlib_queue.Queue()
q2 = stdlib_queue.Queue()
- def thread_fn():
+ def thread_fn() -> None:
q1.get()
q2.put(threading.current_thread())
- async def main():
- async def child():
+ async def main() -> None:
+ async def child() -> None:
await to_thread_run_sync(thread_fn, cancellable=True)
async with _core.open_nursery() as nursery:
@@ -397,7 +397,7 @@ async def child():
@pytest.mark.parametrize("MAX", [3, 5, 10])
@pytest.mark.parametrize("cancel", [False, True])
@pytest.mark.parametrize("use_default_limiter", [False, True])
-async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter):
+async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter) -> None:
# This test is a bit tricky. The goal is to make sure that if we set
# limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever
# running at a time, even if there are more concurrent calls to
@@ -436,7 +436,7 @@ class state:
token = _core.current_trio_token()
- def thread_fn(cancel_scope):
+ def thread_fn(cancel_scope) -> None:
print("thread_fn start")
from_thread_run_sync(cancel_scope.cancel, trio_token=token)
with lock:
@@ -452,7 +452,7 @@ def thread_fn(cancel_scope):
state.running -= 1
print("thread_fn exiting")
- async def run_thread(event):
+ async def run_thread(event) -> None:
with _core.CancelScope() as cancel_scope:
await to_thread_run_sync(
thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel
@@ -501,17 +501,17 @@ async def run_thread(event):
c.total_tokens = orig_total_tokens
-async def test_run_in_worker_thread_custom_limiter():
+async def test_run_in_worker_thread_custom_limiter() -> None:
# Basically just checking that we only call acquire_on_behalf_of and
# release_on_behalf_of, since that's part of our documented API.
record = []
class CustomLimiter:
- async def acquire_on_behalf_of(self, borrower):
+ async def acquire_on_behalf_of(self, borrower) -> None:
record.append("acquire")
self._borrower = borrower
- def release_on_behalf_of(self, borrower):
+ def release_on_behalf_of(self, borrower) -> None:
record.append("release")
assert borrower == self._borrower
@@ -519,11 +519,11 @@ def release_on_behalf_of(self, borrower):
assert record == ["acquire", "release"]
-async def test_run_in_worker_thread_limiter_error():
+async def test_run_in_worker_thread_limiter_error() -> None:
record = []
class BadCapacityLimiter:
- async def acquire_on_behalf_of(self, borrower):
+ async def acquire_on_behalf_of(self, borrower) -> None:
record.append("acquire")
def release_on_behalf_of(self, borrower):
@@ -547,7 +547,7 @@ def release_on_behalf_of(self, borrower):
assert record == ["acquire", "release"]
-async def test_run_in_worker_thread_fail_to_spawn(monkeypatch):
+async def test_run_in_worker_thread_fail_to_spawn(monkeypatch) -> None:
# Test the unlikely but possible case where trying to spawn a thread fails
def bad_start(self, *args):
raise RuntimeError("the engines canna take it captain")
@@ -565,7 +565,7 @@ def bad_start(self, *args):
assert limiter.borrowed_tokens == 0
-async def test_trio_to_thread_run_sync_token():
+async def test_trio_to_thread_run_sync_token() -> None:
# Test that to_thread_run_sync automatically injects the current trio token
# into a spawned thread
def thread_fn():
@@ -577,9 +577,9 @@ def thread_fn():
assert callee_token == caller_token
-async def test_trio_to_thread_run_sync_expected_error():
+async def test_trio_to_thread_run_sync_expected_error() -> None:
# Test correct error when passed async function
- async def async_fn(): # pragma: no cover
+ async def async_fn() -> None: # pragma: no cover
pass
with pytest.raises(TypeError, match="expected a sync function"):
@@ -591,7 +591,7 @@ async def async_fn(): # pragma: no cover
)
-async def test_trio_to_thread_run_sync_contextvars():
+async def test_trio_to_thread_run_sync_contextvars() -> None:
trio_thread = threading.current_thread()
trio_test_contextvar.set("main")
@@ -628,7 +628,7 @@ def g():
assert sniffio.current_async_library() == "trio"
-async def test_trio_from_thread_run_sync():
+async def test_trio_from_thread_run_sync() -> None:
# Test that to_thread_run_sync correctly "hands off" the trio token to
# trio.from_thread.run_sync()
def thread_fn():
@@ -639,26 +639,26 @@ def thread_fn():
assert isinstance(trio_time, float)
# Test correct error when passed async function
- async def async_fn(): # pragma: no cover
+ async def async_fn() -> None: # pragma: no cover
pass
- def thread_fn():
+ def thread_fn() -> None:
from_thread_run_sync(async_fn)
with pytest.raises(TypeError, match="expected a sync function"):
await to_thread_run_sync(thread_fn)
-async def test_trio_from_thread_run():
+async def test_trio_from_thread_run() -> None:
# Test that to_thread_run_sync correctly "hands off" the trio token to
# trio.from_thread.run()
record = []
- async def back_in_trio_fn():
+ async def back_in_trio_fn() -> None:
_core.current_time() # implicitly checks that we're in trio
record.append("back in trio")
- def thread_fn():
+ def thread_fn() -> None:
record.append("in thread")
from_thread_run(back_in_trio_fn)
@@ -666,14 +666,14 @@ def thread_fn():
assert record == ["in thread", "back in trio"]
# Test correct error when passed sync function
- def sync_fn(): # pragma: no cover
+ def sync_fn() -> None: # pragma: no cover
pass
with pytest.raises(TypeError, match="appears to be synchronous"):
await to_thread_run_sync(from_thread_run, sync_fn)
-async def test_trio_from_thread_token():
+async def test_trio_from_thread_token() -> None:
# Test that to_thread_run_sync and spawned trio.from_thread.run_sync()
# share the same Trio token
def thread_fn():
@@ -685,7 +685,7 @@ def thread_fn():
assert callee_token == caller_token
-async def test_trio_from_thread_token_kwarg():
+async def test_trio_from_thread_token_kwarg() -> None:
# Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can
# use an explicitly defined token
def thread_fn(token):
@@ -697,7 +697,7 @@ def thread_fn(token):
assert callee_token == caller_token
-async def test_from_thread_no_token():
+async def test_from_thread_no_token() -> None:
# Test that a "raw call" to trio.from_thread.run() fails because no token
# has been provided
@@ -705,7 +705,7 @@ async def test_from_thread_no_token():
from_thread_run_sync(_core.current_time)
-async def test_trio_from_thread_run_sync_contextvars():
+async def test_trio_from_thread_run_sync_contextvars() -> None:
trio_test_contextvar.set("main")
def thread_fn():
@@ -748,7 +748,7 @@ def back_in_main():
assert back_current_value == "back_in_main"
-async def test_trio_from_thread_run_contextvars():
+async def test_trio_from_thread_run_contextvars() -> None:
trio_test_contextvar.set("main")
def thread_fn():
@@ -791,13 +791,13 @@ async def async_back_in_main():
assert sniffio.current_async_library() == "trio"
-def test_run_fn_as_system_task_catched_badly_typed_token():
+def test_run_fn_as_system_task_catched_badly_typed_token() -> None:
with pytest.raises(RuntimeError):
from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype")
-async def test_from_thread_inside_trio_thread():
- def not_called(): # pragma: no cover
+async def test_from_thread_inside_trio_thread() -> None:
+ def not_called() -> None: # pragma: no cover
assert False
trio_token = _core.current_trio_token()
@@ -806,7 +806,7 @@ def not_called(): # pragma: no cover
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
-def test_from_thread_run_during_shutdown():
+def test_from_thread_run_during_shutdown() -> None:
save = []
record = []
@@ -818,7 +818,7 @@ async def agen():
await to_thread_run_sync(from_thread_run, sleep, 0)
record.append("ok")
- async def main():
+ async def main() -> None:
save.append(agen())
await save[-1].asend(None)
@@ -826,18 +826,18 @@ async def main():
assert record == ["ok"]
-async def test_trio_token_weak_referenceable():
+async def test_trio_token_weak_referenceable() -> None:
token = current_trio_token()
assert isinstance(token, TrioToken)
weak_reference = weakref.ref(token)
assert token is weak_reference()
-async def test_unsafe_cancellable_kwarg():
+async def test_unsafe_cancellable_kwarg() -> None:
# This is a stand in for a numpy ndarray or other objects
# that (maybe surprisingly) lack a notion of truthiness
class BadBool:
- def __bool__(self):
+ def __bool__(self) -> bool:
raise NotImplementedError
with pytest.raises(NotImplementedError):
diff --git a/trio/_tests/test_timeouts.py b/trio/_tests/test_timeouts.py
index 9507d88a78..1491067c05 100644
--- a/trio/_tests/test_timeouts.py
+++ b/trio/_tests/test_timeouts.py
@@ -43,13 +43,13 @@ async def check_takes_about(f, expected_dur):
@slow
-async def test_sleep():
- async def sleep_1():
+async def test_sleep() -> None:
+ async def sleep_1() -> None:
await sleep_until(_core.current_time() + TARGET)
await check_takes_about(sleep_1, TARGET)
- async def sleep_2():
+ async def sleep_2() -> None:
await sleep(TARGET)
await check_takes_about(sleep_2, TARGET)
@@ -63,8 +63,8 @@ async def sleep_2():
@slow
-async def test_move_on_after():
- async def sleep_3():
+async def test_move_on_after() -> None:
+ async def sleep_3() -> None:
with move_on_after(TARGET):
await sleep(100)
@@ -72,8 +72,8 @@ async def sleep_3():
@slow
-async def test_fail():
- async def sleep_4():
+async def test_fail() -> None:
+ async def sleep_4() -> None:
with fail_at(_core.current_time() + TARGET):
await sleep(100)
@@ -83,7 +83,7 @@ async def sleep_4():
with fail_at(_core.current_time() + 100):
await sleep(0)
- async def sleep_5():
+ async def sleep_5() -> None:
with fail_after(TARGET):
await sleep(100)
@@ -94,7 +94,7 @@ async def sleep_5():
await sleep(0)
-async def test_timeouts_raise_value_error():
+async def test_timeouts_raise_value_error() -> None:
# deadlines are allowed to be negative, but not delays.
# neither delays nor deadlines are allowed to be NaN
diff --git a/trio/_tests/test_tracing.py b/trio/_tests/test_tracing.py
index e5110eaff3..0cef2b0f44 100644
--- a/trio/_tests/test_tracing.py
+++ b/trio/_tests/test_tracing.py
@@ -25,7 +25,7 @@ async def coro3_async_gen(event: trio.Event) -> None:
pass
-async def test_task_iter_await_frames():
+async def test_task_iter_await_frames() -> None:
async with trio.open_nursery() as nursery:
event = trio.Event()
nursery.start_soon(coro3, event)
@@ -42,7 +42,7 @@ async def test_task_iter_await_frames():
nursery.cancel_scope.cancel()
-async def test_task_iter_await_frames_async_gen():
+async def test_task_iter_await_frames_async_gen() -> None:
async with trio.open_nursery() as nursery:
event = trio.Event()
nursery.start_soon(coro3_async_gen, event)
diff --git a/trio/_tests/test_unix_pipes.py b/trio/_tests/test_unix_pipes.py
index 0b0d2ceb23..f29ee241c6 100644
--- a/trio/_tests/test_unix_pipes.py
+++ b/trio/_tests/test_unix_pipes.py
@@ -57,7 +57,7 @@ async def make_clogged_pipe():
return s, r
-async def test_send_pipe():
+async def test_send_pipe() -> None:
r, w = os.pipe()
async with FdStream(w) as send:
assert send.fileno() == w
@@ -67,7 +67,7 @@ async def test_send_pipe():
os.close(r)
-async def test_receive_pipe():
+async def test_receive_pipe() -> None:
r, w = os.pipe()
async with FdStream(r) as recv:
assert (recv.fileno()) == r
@@ -77,15 +77,15 @@ async def test_receive_pipe():
os.close(w)
-async def test_pipes_combined():
+async def test_pipes_combined() -> None:
write, read = await make_pipe()
count = 2**20
- async def sender():
+ async def sender() -> None:
big = bytearray(count)
await write.send_all(big)
- async def reader():
+ async def reader() -> None:
await wait_all_tasks_blocked()
received = 0
while received < count:
@@ -101,7 +101,7 @@ async def reader():
await write.aclose()
-async def test_pipe_errors():
+async def test_pipe_errors() -> None:
with pytest.raises(TypeError):
FdStream(None)
@@ -112,7 +112,7 @@ async def test_pipe_errors():
await s.receive_some(0)
-async def test_del():
+async def test_del() -> None:
w, r = await make_pipe()
f1, f2 = w.fileno(), r.fileno()
del w, r
@@ -127,7 +127,7 @@ async def test_del():
assert excinfo.value.errno == errno.EBADF
-async def test_async_with():
+async def test_async_with() -> None:
w, r = await make_pipe()
async with w, r:
pass
@@ -144,7 +144,7 @@ async def test_async_with():
assert excinfo.value.errno == errno.EBADF
-async def test_misdirected_aclose_regression():
+async def test_misdirected_aclose_regression() -> None:
# https://github.com/python-trio/trio/issues/661#issuecomment-456582356
w, r = await make_pipe()
old_r_fd = r.fileno()
@@ -164,7 +164,7 @@ async def test_misdirected_aclose_regression():
# And now set up a background task that's working on the new receive
# handle
- async def expect_eof():
+ async def expect_eof() -> None:
assert await r2.receive_some(10) == b""
async with _core.open_nursery() as nursery:
@@ -182,7 +182,7 @@ async def expect_eof():
os.close(w2_fd)
-async def test_close_at_bad_time_for_receive_some(monkeypatch):
+async def test_close_at_bad_time_for_receive_some(monkeypatch) -> None:
# We used to have race conditions where if one task was using the pipe,
# and another closed it at *just* the wrong moment, it would give an
# unexpected error instead of ClosedResourceError:
@@ -190,13 +190,13 @@ async def test_close_at_bad_time_for_receive_some(monkeypatch):
#
# This tests what happens if the pipe gets closed in the moment *between*
# when receive_some wakes up, and when it tries to call os.read
- async def expect_closedresourceerror():
+ async def expect_closedresourceerror() -> None:
with pytest.raises(_core.ClosedResourceError):
await r.receive_some(10)
orig_wait_readable = _core._run.TheIOManager.wait_readable
- async def patched_wait_readable(*args, **kwargs):
+ async def patched_wait_readable(*args, **kwargs) -> None:
await orig_wait_readable(*args, **kwargs)
await r.aclose()
@@ -210,7 +210,7 @@ async def patched_wait_readable(*args, **kwargs):
await s.send_all(b"x")
-async def test_close_at_bad_time_for_send_all(monkeypatch):
+async def test_close_at_bad_time_for_send_all(monkeypatch) -> None:
# We used to have race conditions where if one task was using the pipe,
# and another closed it at *just* the wrong moment, it would give an
# unexpected error instead of ClosedResourceError:
@@ -218,13 +218,13 @@ async def test_close_at_bad_time_for_send_all(monkeypatch):
#
# This tests what happens if the pipe gets closed in the moment *between*
# when send_all wakes up, and when it tries to call os.write
- async def expect_closedresourceerror():
+ async def expect_closedresourceerror() -> None:
with pytest.raises(_core.ClosedResourceError):
await s.send_all(b"x" * 100)
orig_wait_writable = _core._run.TheIOManager.wait_writable
- async def patched_wait_writable(*args, **kwargs):
+ async def patched_wait_writable(*args, **kwargs) -> None:
await orig_wait_writable(*args, **kwargs)
await s.aclose()
@@ -257,7 +257,7 @@ async def patched_wait_writable(*args, **kwargs):
sys.platform.startswith("freebsd"),
reason="no way to make read() return a bizarro error on FreeBSD",
)
-async def test_bizarro_OSError_from_receive():
+async def test_bizarro_OSError_from_receive() -> None:
# Make sure that if the read syscall returns some bizarro error, then we
# get a BrokenResourceError. This is incredibly unlikely; there's almost
# no way to trigger a failure here intentionally (except for EBADF, but we
@@ -277,5 +277,5 @@ async def test_bizarro_OSError_from_receive():
@skip_if_fbsd_pipes_broken
-async def test_pipe_fully():
+async def test_pipe_fully() -> None:
await check_one_way_stream(make_pipe, make_clogged_pipe)
diff --git a/trio/_tests/test_util.py b/trio/_tests/test_util.py
index 9cffaa30d3..ef99e8f66f 100644
--- a/trio/_tests/test_util.py
+++ b/trio/_tests/test_util.py
@@ -24,10 +24,10 @@
from ..testing import wait_all_tasks_blocked
-def test_signal_raise():
+def test_signal_raise() -> None:
record = []
- def handler(signum, _):
+ def handler(signum, _) -> None:
record.append(signum)
old = signal.signal(signal.SIGFPE, handler)
@@ -38,7 +38,7 @@ def handler(signum, _):
assert record == [signal.SIGFPE]
-async def test_ConflictDetector():
+async def test_ConflictDetector() -> None:
ul1 = ConflictDetector("ul1")
ul2 = ConflictDetector("ul2")
@@ -52,7 +52,7 @@ async def test_ConflictDetector():
pass # pragma: no cover
assert "ul1" in str(excinfo.value)
- async def wait_with_ul1():
+ async def wait_with_ul1() -> None:
with ul1:
await wait_all_tasks_blocked()
@@ -63,7 +63,7 @@ async def wait_with_ul1():
assert "ul1" in str(excinfo.value)
-def test_module_metadata_is_fixed_up():
+def test_module_metadata_is_fixed_up() -> None:
import trio
import trio.testing
@@ -87,10 +87,10 @@ def test_module_metadata_is_fixed_up():
assert trio.to_thread.run_sync.__qualname__ == "run_sync"
-async def test_is_main_thread():
+async def test_is_main_thread() -> None:
assert is_main_thread()
- def not_main_thread():
+ def not_main_thread() -> None:
assert not is_main_thread()
await trio.to_thread.run_sync(not_main_thread)
@@ -98,13 +98,13 @@ def not_main_thread():
# @coroutine is deprecated since python 3.8, which is fine with us.
@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning")
-def test_coroutine_or_error():
+def test_coroutine_or_error() -> None:
class Deferred:
"Just kidding"
with ignore_coroutine_never_awaited_warnings():
- async def f(): # pragma: no cover
+ async def f() -> None: # pragma: no cover
pass
with pytest.raises(TypeError) as excinfo:
@@ -156,7 +156,7 @@ async def async_gen(arg): # pragma: no cover
del excinfo
-def test_generic_function():
+def test_generic_function() -> None:
@generic_function
def test_func(arg):
"""Look, a docstring!"""
@@ -187,11 +187,11 @@ class SubClass(FinalClass): # type: ignore[misc]
pass
-def test_no_public_constructor_metaclass():
+def test_no_public_constructor_metaclass() -> None:
"""The NoPublicConstructor metaclass prevents calling the constructor directly."""
class SpecialClass(metaclass=NoPublicConstructor):
- def __init__(self, a: int, b: float):
+ def __init__(self, a: int, b: float) -> None:
"""Check arguments can be passed to __init__."""
assert a == 8
assert b == 3.14
@@ -203,7 +203,7 @@ def __init__(self, a: int, b: float):
assert isinstance(SpecialClass._create(8, b=3.14), SpecialClass)
-def test_fixup_module_metadata():
+def test_fixup_module_metadata() -> None:
# Ignores modules not in the trio.X tree.
non_trio_module = types.ModuleType("not_trio")
non_trio_module.some_func = lambda: None
diff --git a/trio/_tests/test_wait_for_object.py b/trio/_tests/test_wait_for_object.py
index ea16684289..53e771b7ed 100644
--- a/trio/_tests/test_wait_for_object.py
+++ b/trio/_tests/test_wait_for_object.py
@@ -16,7 +16,7 @@
from .._wait_for_object import WaitForMultipleObjects_sync, WaitForSingleObject
-async def test_WaitForMultipleObjects_sync():
+async def test_WaitForMultipleObjects_sync() -> None:
# This does a series of tests where we set/close the handle before
# initiating the waiting for it.
#
@@ -70,7 +70,7 @@ async def test_WaitForMultipleObjects_sync():
@slow
-async def test_WaitForMultipleObjects_sync_slow():
+async def test_WaitForMultipleObjects_sync_slow() -> None:
# This does a series of test in which the main thread sync-waits for
# handles, while we spawn a thread to set the handles after a short while.
@@ -125,7 +125,7 @@ async def test_WaitForMultipleObjects_sync_slow():
print("test_WaitForMultipleObjects_sync_slow thread-set second OK")
-async def test_WaitForSingleObject():
+async def test_WaitForSingleObject() -> None:
# This does a series of test for setting/closing the handle before
# initiating the wait.
@@ -160,7 +160,7 @@ async def test_WaitForSingleObject():
@slow
-async def test_WaitForSingleObject_slow():
+async def test_WaitForSingleObject_slow() -> None:
# This does a series of test for setting the handle in another task,
# and cancelling the wait task.
@@ -168,7 +168,7 @@ async def test_WaitForSingleObject_slow():
# the timeout with a certain margin.
TIMEOUT = 0.3
- async def signal_soon_async(handle):
+ async def signal_soon_async(handle) -> None:
await _timeouts.sleep(TIMEOUT)
kernel32.SetEvent(handle)
diff --git a/trio/_tests/test_windows_pipes.py b/trio/_tests/test_windows_pipes.py
index 5c4bae7d25..399d7116bb 100644
--- a/trio/_tests/test_windows_pipes.py
+++ b/trio/_tests/test_windows_pipes.py
@@ -24,14 +24,14 @@ async def make_pipe() -> Tuple[PipeSendStream, PipeReceiveStream]:
return PipeSendStream(w), PipeReceiveStream(r)
-async def test_pipe_typecheck():
+async def test_pipe_typecheck() -> None:
with pytest.raises(TypeError):
PipeSendStream(1.0)
with pytest.raises(TypeError):
PipeReceiveStream(None)
-async def test_pipe_error_on_close():
+async def test_pipe_error_on_close() -> None:
# Make sure we correctly handle a failure from kernel32.CloseHandle
r, w = pipe()
@@ -47,18 +47,18 @@ async def test_pipe_error_on_close():
await receive_stream.aclose()
-async def test_pipes_combined():
+async def test_pipes_combined() -> None:
write, read = await make_pipe()
count = 2**20
replicas = 3
- async def sender():
+ async def sender() -> None:
async with write:
big = bytearray(count)
for _ in range(replicas):
await write.send_all(big)
- async def reader():
+ async def reader() -> None:
async with read:
await wait_all_tasks_blocked()
total_received = 0
@@ -76,7 +76,7 @@ async def reader():
n.start_soon(reader)
-async def test_async_with():
+async def test_async_with() -> None:
w, r = await make_pipe()
async with w, r:
pass
@@ -87,11 +87,11 @@ async def test_async_with():
await r.receive_some(10)
-async def test_close_during_write():
+async def test_close_during_write() -> None:
w, r = await make_pipe()
async with _core.open_nursery() as nursery:
- async def write_forever():
+ async def write_forever() -> None:
with pytest.raises(_core.ClosedResourceError) as excinfo:
while True:
await w.send_all(b"x" * 4096)
@@ -102,7 +102,7 @@ async def write_forever():
await w.aclose()
-async def test_pipe_fully():
+async def test_pipe_fully() -> None:
# passing make_clogged_pipe tests wait_send_all_might_not_block, and we
# can't implement that on Windows
await check_one_way_stream(make_pipe, None)
diff --git a/trio/_tests/tools/test_gen_exports.py b/trio/_tests/tools/test_gen_exports.py
index b4e23916a0..ee51827a59 100644
--- a/trio/_tests/tools/test_gen_exports.py
+++ b/trio/_tests/tools/test_gen_exports.py
@@ -60,12 +60,12 @@ async def not_public_async(self):
"""
-def test_get_public_methods():
+def test_get_public_methods() -> None:
methods = list(get_public_methods(ast.parse(SOURCE)))
assert {m.name for m in methods} == {"public_func", "public_async_func"}
-def test_create_pass_through_args():
+def test_create_pass_through_args() -> None:
testcases = [
("def f()", "()"),
("def f(one)", "(one)"),
@@ -91,7 +91,7 @@ def test_create_pass_through_args():
@skip_lints
@pytest.mark.parametrize("imports", ["", IMPORT_1, IMPORT_2, IMPORT_3])
-def test_process(tmp_path, imports):
+def test_process(tmp_path, imports) -> None:
try:
import black # noqa: F401
# there's no dedicated CI run that has astor+isort, but lacks black.
From a909b0a3f7a4311b66a7dde578c28f4f84d327db Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sun, 15 Oct 2023 15:22:57 +0200
Subject: [PATCH 03/35] pyannotate --aggressive
---
trio/_tests/test_exports.py | 2 +-
trio/_tests/test_path.py | 6 +++---
trio/_tests/test_scheduler_determinism.py | 2 +-
trio/_tests/test_ssl.py | 2 +-
trio/_tests/test_subprocess.py | 4 ++--
trio/_tests/test_sync.py | 2 +-
6 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py
index 7a016e86d3..6861ba3708 100644
--- a/trio/_tests/test_exports.py
+++ b/trio/_tests/test_exports.py
@@ -126,7 +126,7 @@ def iter_modules(
# https://github.com/pypa/setuptools/issues/3274
"ignore:module 'sre_constants' is deprecated:DeprecationWarning",
)
-def test_static_tool_sees_all_symbols(tool, modname, tmpdir) -> None:
+def test_static_tool_sees_all_symbols(tool, modname: str, tmpdir) -> None:
module = importlib.import_module(modname)
def no_underscores(symbols):
diff --git a/trio/_tests/test_path.py b/trio/_tests/test_path.py
index 1d17029bdf..bb24764c23 100644
--- a/trio/_tests/test_path.py
+++ b/trio/_tests/test_path.py
@@ -14,7 +14,7 @@ def path(tmpdir):
return trio.Path(p)
-def method_pair(path, method_name):
+def method_pair(path, method_name: str):
path = pathlib.Path(path)
async_path = trio.Path(path)
return getattr(path, method_name), getattr(async_path, method_name)
@@ -103,7 +103,7 @@ async def test_async_method_signature(path) -> None:
@pytest.mark.parametrize("method_name", ["is_dir", "is_file"])
-async def test_compare_async_stat_methods(method_name) -> None:
+async def test_compare_async_stat_methods(method_name: str) -> None:
method, async_method = method_pair(".", method_name)
result = method()
@@ -118,7 +118,7 @@ async def test_invalid_name_not_wrapped(path) -> None:
@pytest.mark.parametrize("method_name", ["absolute", "resolve"])
-async def test_async_methods_rewrap(method_name) -> None:
+async def test_async_methods_rewrap(method_name: str) -> None:
method, async_method = method_pair(".", method_name)
result = method()
diff --git a/trio/_tests/test_scheduler_determinism.py b/trio/_tests/test_scheduler_determinism.py
index 1c438f136c..4c0da698ce 100644
--- a/trio/_tests/test_scheduler_determinism.py
+++ b/trio/_tests/test_scheduler_determinism.py
@@ -5,7 +5,7 @@ async def scheduler_trace():
"""Returns a scheduler-dependent value we can use to check determinism."""
trace = []
- async def tracer(name) -> None:
+ async def tracer(name: str) -> None:
for i in range(50):
trace.append((name, i))
await trio.sleep(0)
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index 06ac2e694c..8af210a0ff 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -88,7 +88,7 @@ def client_ctx(request):
# The blocking socket server.
-def ssl_echo_serve_sync(sock, *, expect_fail=False):
+def ssl_echo_serve_sync(sock, *, expect_fail: bool = False):
try:
wrapped = SERVER_CTX.wrap_socket(
sock, server_side=True, suppress_ragged_eofs=False
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index 2d805a1814..0f993976e0 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -241,10 +241,10 @@ async def test_interactive(background_process) -> None:
) as proc:
newline = b"\n" if posix else b"\r\n"
- async def expect(idx, request) -> None:
+ async def expect(idx: int, request) -> None:
async with _core.open_nursery() as nursery:
- async def drain_one(stream, count, digit) -> None:
+ async def drain_one(stream, count: int, digit) -> None:
while count > 0:
result = await stream.receive_some(count)
assert result == (f"{digit}".encode() * len(result))
diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py
index 4325e33a6c..7747740c9f 100644
--- a/trio/_tests/test_sync.py
+++ b/trio/_tests/test_sync.py
@@ -527,7 +527,7 @@ async def test_generic_lock_fifo_fairness(lock_factory) -> None:
record = []
LOOPS = 5
- async def loopy(name, lock_like) -> None:
+ async def loopy(name: str, lock_like) -> None:
# Record the order each task was initially scheduled in
initial_order.append(name)
for _ in range(LOOPS):
From 6cad1937c4c78811cf38755f175e53e2155e2abc Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 17 Oct 2023 12:54:46 +0200
Subject: [PATCH 04/35] WIP before merging origin/master
---
pyproject.toml | 25 ++---
trio/_core/_tests/test_guest_mode.py | 98 ++++++++++---------
trio/_core/_tests/test_instrumentation.py | 10 +-
trio/_core/_tests/test_ki.py | 92 ++++++++---------
trio/_core/_tests/test_mock_clock.py | 22 +++--
trio/_core/_tests/test_multierror.py | 37 +++----
trio/_core/_tests/test_parking_lot.py | 34 ++++---
trio/_core/_tests/test_run.py | 4 +-
trio/_core/_tests/test_thread_cache.py | 36 +++----
trio/_core/_tests/tutil.py | 8 +-
trio/_tests/test_file_io.py | 2 +-
trio/_tests/test_highlevel_generic.py | 10 +-
trio/_tests/test_highlevel_serve_listeners.py | 12 ++-
trio/_tests/test_highlevel_ssl_helpers.py | 14 ++-
trio/_tests/test_path.py | 50 +++++-----
trio/_tests/test_ssl.py | 13 ++-
trio/_tests/test_subprocess.py | 8 +-
trio/_tests/test_sync.py | 16 +--
trio/_tests/test_testing.py | 2 +-
trio/_tests/test_threads.py | 64 +++++++-----
trio/_tests/test_util.py | 43 ++++----
trio/_tests/test_windows_pipes.py | 18 ++--
22 files changed, 332 insertions(+), 286 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index e12f95c309..81e8232965 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,10 +80,14 @@ disallow_untyped_defs = true
check_untyped_defs = true
disallow_untyped_calls = false
-
-# partially typed tests
+# files not yet fully typed
[[tool.mypy.overrides]]
module = [
+# internal
+"trio/_windows_pipes",
+
+# tests
+"trio/testing/_fake_net",
"trio/_core/_tests/test_guest_mode",
"trio/_core/_tests/test_ki",
"trio/_core/_tests/test_mock_clock",
@@ -92,22 +96,7 @@ module = [
"trio/_core/_tests/test_multierror_scripts/simple_excepthook",
"trio/_core/_tests/test_parking_lot",
"trio/_core/_tests/test_thread_cache",
-]
-check_untyped_defs = true
-disallow_any_decorated = false
-disallow_any_generics = false
-disallow_any_unimported = false
-disallow_incomplete_defs = false
-disallow_untyped_defs = false
-
-# files not yet fully typed
-[[tool.mypy.overrides]]
-module = [
-# internal
-"trio/_windows_pipes",
-
-# tests
-"trio/testing/_fake_net",
+"trio/_core/_tests/test_unbounded_queue",
"trio/_tests/test_exports",
"trio/_tests/test_file_io",
"trio/_tests/test_highlevel_generic",
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index 21980b9b72..3789cdefa3 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -35,7 +35,7 @@ def trivial_guest_run(trio_fn, *, in_host_after_start=None, **start_guest_run_kw
host_thread = threading.current_thread()
- def run_sync_soon_threadsafe(fn: Callable):
+ def run_sync_soon_threadsafe(fn: Callable) -> None:
nonlocal todo
if host_thread is threading.current_thread(): # pragma: no cover
crash = partial(
@@ -44,7 +44,7 @@ def run_sync_soon_threadsafe(fn: Callable):
todo.put(("run", crash))
todo.put(("run", fn))
- def run_sync_soon_not_threadsafe(fn: Callable):
+ def run_sync_soon_not_threadsafe(fn: Callable) -> None:
nonlocal todo
if host_thread is not threading.current_thread(): # pragma: no cover
crash = partial(
@@ -53,7 +53,7 @@ def run_sync_soon_not_threadsafe(fn: Callable):
todo.put(("run", crash))
todo.put(("run", fn))
- def done_callback(outcome: Outcome):
+ def done_callback(outcome: Outcome) -> None:
nonlocal todo
todo.put(("unwrap", outcome))
@@ -84,8 +84,8 @@ def done_callback(outcome: Outcome):
del todo, run_sync_soon_threadsafe, done_callback
-def test_guest_trivial():
- async def trio_return(in_host):
+def test_guest_trivial() -> None:
+ async def trio_return(in_host) -> str:
await trio.sleep(0)
return "ok"
@@ -98,14 +98,14 @@ async def trio_fail(in_host):
trivial_guest_run(trio_fail)
-def test_guest_can_do_io():
- async def trio_main(in_host):
+def test_guest_can_do_io() -> None:
+ async def trio_main(in_host) -> None:
record = []
a, b = trio.socket.socketpair()
with a, b:
async with trio.open_nursery() as nursery:
- async def do_receive():
+ async def do_receive() -> None:
record.append(await a.recv(1))
nursery.start_soon(do_receive)
@@ -118,17 +118,17 @@ async def do_receive():
trivial_guest_run(trio_main)
-def test_guest_is_initialized_when_start_returns():
+def test_guest_is_initialized_when_start_returns() -> None:
trio_token = None
record = []
- async def trio_main(in_host):
+ async def trio_main(in_host) -> str:
record.append("main task ran")
await trio.sleep(0)
assert trio.lowlevel.current_trio_token() is trio_token
return "ok"
- def after_start():
+ def after_start() -> None:
# We should get control back before the main task executes any code
assert record == []
@@ -137,7 +137,7 @@ def after_start():
trio_token.run_sync_soon(record.append, "run_sync_soon cb ran")
@trio.lowlevel.spawn_system_task
- async def early_task():
+ async def early_task() -> None:
record.append("system task ran")
await trio.sleep(0)
@@ -153,7 +153,7 @@ class BadClock:
def start_clock(self):
raise ValueError("whoops")
- def after_start_never_runs(): # pragma: no cover
+ def after_start_never_runs() -> None: # pragma: no cover
pytest.fail("shouldn't get here")
trivial_guest_run(
@@ -161,8 +161,8 @@ def after_start_never_runs(): # pragma: no cover
)
-def test_host_can_directly_wake_trio_task():
- async def trio_main(in_host):
+def test_host_can_directly_wake_trio_task() -> None:
+ async def trio_main(in_host) -> str:
ev = trio.Event()
in_host(ev.set)
await ev.wait()
@@ -171,11 +171,11 @@ async def trio_main(in_host):
assert trivial_guest_run(trio_main) == "ok"
-def test_host_altering_deadlines_wakes_trio_up():
- def set_deadline(cscope, new_deadline):
+def test_host_altering_deadlines_wakes_trio_up() -> None:
+ def set_deadline(cscope, new_deadline) -> None:
cscope.deadline = new_deadline
- async def trio_main(in_host):
+ async def trio_main(in_host) -> str:
with trio.CancelScope() as cscope:
in_host(lambda: set_deadline(cscope, -inf))
await trio.sleep_forever()
@@ -194,11 +194,11 @@ async def trio_main(in_host):
assert trivial_guest_run(trio_main) == "ok"
-def test_guest_mode_sniffio_integration():
+def test_guest_mode_sniffio_integration() -> None:
from sniffio import current_async_library, thread_local as sniffio_library
- async def trio_main(in_host):
- async def synchronize():
+ async def trio_main(in_host) -> str:
+ async def synchronize() -> None:
"""Wait for all in_host() calls issued so far to complete."""
evt = trio.Event()
in_host(evt.set)
@@ -223,10 +223,10 @@ async def synchronize():
sniffio_library.name = None
-def test_warn_set_wakeup_fd_overwrite():
+def test_warn_set_wakeup_fd_overwrite() -> None:
assert signal.set_wakeup_fd(-1) == -1
- async def trio_main(in_host):
+ async def trio_main(in_host) -> str:
return "ok"
a, b = socket.socketpair()
@@ -268,7 +268,7 @@ async def trio_main(in_host):
signal.set_wakeup_fd(a.fileno())
try:
- async def trio_check_wakeup_fd_unaltered(in_host):
+ async def trio_check_wakeup_fd_unaltered(in_host) -> str:
fd = signal.set_wakeup_fd(-1)
assert fd == a.fileno()
signal.set_wakeup_fd(fd)
@@ -287,19 +287,19 @@ async def trio_check_wakeup_fd_unaltered(in_host):
assert signal.set_wakeup_fd(-1) == a.fileno()
-def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked():
+def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked() -> None:
# This is designed to hit the branch in unrolled_run where:
# idle_primed=True
# runner.runq is empty
# events is Truth-y
# ...and confirm that in this case, wait_all_tasks_blocked does not get
# triggered.
- def set_deadline(cscope, new_deadline):
+ def set_deadline(cscope, new_deadline) -> None:
print(f"setting deadline {new_deadline}")
cscope.deadline = new_deadline
- async def trio_main(in_host):
- async def sit_in_wait_all_tasks_blocked(watb_cscope):
+ async def trio_main(in_host) -> str:
+ async def sit_in_wait_all_tasks_blocked(watb_cscope) -> None:
with watb_cscope:
# Overall point of this test is that this
# wait_all_tasks_blocked should *not* return normally, but
@@ -308,7 +308,7 @@ async def sit_in_wait_all_tasks_blocked(watb_cscope):
assert False # pragma: no cover
assert watb_cscope.cancelled_caught
- async def get_woken_by_host_deadline(watb_cscope):
+ async def get_woken_by_host_deadline(watb_cscope) -> None:
with trio.CancelScope() as cscope:
print("scheduling stuff to happen")
@@ -360,13 +360,13 @@ def after_io_wait(self, timeout: float) -> None:
@restore_unraisablehook()
-def test_guest_warns_if_abandoned():
+def test_guest_warns_if_abandoned() -> None:
# This warning is emitted from the garbage collector. So we have to make
# sure that our abandoned run is garbage. The easiest way to do this is to
# put it into a function, so that we're sure all the local state,
# traceback frames, etc. are garbage once it returns.
- def do_abandoned_guest_run():
- async def abandoned_main(in_host):
+ def do_abandoned_guest_run() -> None:
+ async def abandoned_main(in_host) -> None:
in_host(lambda: 1 / 0)
while True:
await trio.sleep(0)
@@ -401,13 +401,13 @@ async def abandoned_main(in_host):
trio.current_time()
-def aiotrio_run(trio_fn, *, pass_not_threadsafe=True, **start_guest_run_kwargs):
+def aiotrio_run(trio_fn, *, pass_not_threadsafe: bool = True, **start_guest_run_kwargs):
loop = asyncio.new_event_loop()
async def aio_main():
trio_done_fut = loop.create_future()
- def trio_done_callback(main_outcome):
+ def trio_done_callback(main_outcome) -> None:
print(f"trio_fn finished: {main_outcome!r}")
trio_done_fut.set_result(main_outcome)
@@ -429,8 +429,8 @@ def trio_done_callback(main_outcome):
loop.close()
-def test_guest_mode_on_asyncio():
- async def trio_main():
+def test_guest_mode_on_asyncio() -> None:
+ async def trio_main() -> str:
print("trio_main!")
to_trio, from_aio = trio.open_memory_channel[int](float("inf"))
@@ -451,6 +451,8 @@ async def trio_main():
aio_task.cancel()
return "trio-main-done"
+ raise AssertionError("should never be reached")
+
async def aio_pingpong(from_trio, to_trio):
print("aio_pingpong!")
@@ -488,10 +490,10 @@ async def aio_pingpong(from_trio, to_trio):
)
-def test_guest_mode_internal_errors(monkeypatch, recwarn):
+def test_guest_mode_internal_errors(monkeypatch, recwarn) -> None:
with monkeypatch.context() as m:
- async def crash_in_run_loop(in_host):
+ async def crash_in_run_loop(in_host) -> None:
m.setattr("trio._core._run.GLOBAL_RUN_CONTEXT.runner.runq", "HI")
await trio.sleep(1)
@@ -500,7 +502,7 @@ async def crash_in_run_loop(in_host):
with monkeypatch.context() as m:
- async def crash_in_io(in_host):
+ async def crash_in_io(in_host) -> None:
m.setattr("trio._core._run.TheIOManager.get_events", None)
await trio.sleep(0)
@@ -509,7 +511,7 @@ async def crash_in_io(in_host):
with monkeypatch.context() as m:
- async def crash_in_worker_thread_io(in_host):
+ async def crash_in_worker_thread_io(in_host) -> None:
t = threading.current_thread()
old_get_events = trio._core._run.TheIOManager.get_events
@@ -529,11 +531,11 @@ def bad_get_events(*args):
gc_collect_harder()
-def test_guest_mode_ki():
+def test_guest_mode_ki() -> None:
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
# Check SIGINT in Trio func and in host func
- async def trio_main(in_host):
+ async def trio_main(in_host) -> None:
with pytest.raises(KeyboardInterrupt):
signal_raise(signal.SIGINT)
@@ -561,7 +563,7 @@ async def trio_main_raising(in_host):
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
-def test_guest_mode_autojump_clock_threshold_changing():
+def test_guest_mode_autojump_clock_threshold_changing() -> None:
# This is super obscure and probably no-one will ever notice, but
# technically mutating the MockClock.autojump_threshold from the host
# should wake up the guest, so let's test it.
@@ -570,7 +572,7 @@ def test_guest_mode_autojump_clock_threshold_changing():
DURATION = 120
- async def trio_main(in_host):
+ async def trio_main(in_host) -> None:
assert trio.current_time() == 0
in_host(lambda: setattr(clock, "autojump_threshold", 0))
await trio.sleep(DURATION)
@@ -586,12 +588,12 @@ async def trio_main(in_host):
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy")
@restore_unraisablehook()
-def test_guest_mode_asyncgens():
+def test_guest_mode_asyncgens() -> None:
import sniffio
record = set()
- async def agen(label):
+ async def agen(label: str):
assert sniffio.current_async_library() == label
try:
yield 1
@@ -603,10 +605,10 @@ async def agen(label):
pass
record.add((label, library))
- async def iterate_in_aio():
+ async def iterate_in_aio() -> None:
await agen("asyncio").asend(None)
- async def trio_main():
+ async def trio_main() -> None:
task = asyncio.ensure_future(iterate_in_aio())
done_evt = trio.Event()
task.add_done_callback(lambda _: done_evt.set())
diff --git a/trio/_core/_tests/test_instrumentation.py b/trio/_core/_tests/test_instrumentation.py
index 8b103dbfdc..ecd7585ef2 100644
--- a/trio/_core/_tests/test_instrumentation.py
+++ b/trio/_core/_tests/test_instrumentation.py
@@ -75,9 +75,9 @@ async def main() -> None:
# reschedules the task immediately upon yielding, before the
# after_task_step event fires.
expected = (
- [("before_run",), ("schedule", task)]
+ [("before_run", None), ("schedule", task)]
+ [("before", task), ("schedule", task), ("after", task)] * 5
- + [("before", task), ("after", task), ("after_run",)]
+ + [("before", task), ("after", task), ("after_run", None)]
)
assert r1.record == r2.record + r3.record
assert task is not None
@@ -104,7 +104,7 @@ async def main() -> None:
_core.run(main, instruments=[r])
expected = [
- ("before_run",),
+ ("before_run", None),
("schedule", tasks["t1"]),
("schedule", tasks["t2"]),
{
@@ -121,7 +121,7 @@ async def main() -> None:
("before", tasks["t2"]),
("after", tasks["t2"]),
},
- ("after_run",),
+ ("after_run", None),
]
print(list(r.filter_tasks(tasks.values())))
check_sequence_matches(list(r.filter_tasks(tasks.values())), expected)
@@ -199,7 +199,7 @@ async def main() -> Task:
# the TaskRecorder kept going throughout, even though the BrokenInstrument
# was disabled
assert ("after", main_task) in r.record
- assert ("after_run",) in r.record
+ assert ("after_run", None) in r.record
# And we got a log message
assert caplog.records[0].exc_info is not None
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info
diff --git a/trio/_core/_tests/test_ki.py b/trio/_core/_tests/test_ki.py
index 0fda688194..60a6b17336 100644
--- a/trio/_core/_tests/test_ki.py
+++ b/trio/_core/_tests/test_ki.py
@@ -24,16 +24,16 @@
from ..._core import Abort, RaiseCancelT
-def ki_self():
+def ki_self() -> None:
signal_raise(signal.SIGINT)
-def test_ki_self():
+def test_ki_self() -> None:
with pytest.raises(KeyboardInterrupt):
ki_self()
-async def test_ki_enabled():
+async def test_ki_enabled() -> None:
# Regular tasks aren't KI-protected
assert not _core.currently_ki_protected()
@@ -41,7 +41,7 @@ async def test_ki_enabled():
token = _core.current_trio_token()
record = []
- def check():
+ def check() -> None:
record.append(_core.currently_ki_protected())
token.run_sync_soon(check)
@@ -49,23 +49,23 @@ def check():
assert record == [True]
@_core.enable_ki_protection
- def protected():
+ def protected() -> None:
assert _core.currently_ki_protected()
unprotected()
@_core.disable_ki_protection
- def unprotected():
+ def unprotected() -> None:
assert not _core.currently_ki_protected()
protected()
@_core.enable_ki_protection
- async def aprotected():
+ async def aprotected() -> None:
assert _core.currently_ki_protected()
await aunprotected()
@_core.disable_ki_protection
- async def aunprotected():
+ async def aunprotected() -> None:
assert not _core.currently_ki_protected()
await aprotected()
@@ -102,16 +102,16 @@ def gen_unprotected():
# .throw(), not the actual caller. So child() here would have a caller deep in
# the guts of the run loop, and always be protected, even when it shouldn't
# have been. (Solution: we don't use .throw() anymore.)
-async def test_ki_enabled_after_yield_briefly():
+async def test_ki_enabled_after_yield_briefly() -> None:
@_core.enable_ki_protection
- async def protected():
+ async def protected() -> None:
await child(True)
@_core.disable_ki_protection
- async def unprotected():
+ async def unprotected() -> None:
await child(False)
- async def child(expected):
+ async def child(expected) -> None:
import traceback
traceback.print_stack()
@@ -146,10 +146,10 @@ def protected_manager():
@pytest.mark.skipif(async_generator is None, reason="async_generator not installed")
-async def test_async_generator_agen_protection():
+async def test_async_generator_agen_protection() -> None:
@_core.enable_ki_protection
- @async_generator
- async def agen_protected1():
+ @async_generator # type: ignore[misc] # untyped generator
+ async def agen_protected1() -> None:
assert _core.currently_ki_protected()
try:
await yield_()
@@ -157,8 +157,8 @@ async def agen_protected1():
assert _core.currently_ki_protected()
@_core.disable_ki_protection
- @async_generator
- async def agen_unprotected1():
+ @async_generator # type: ignore[misc] # untyped generator
+ async def agen_unprotected1() -> None:
assert not _core.currently_ki_protected()
try:
await yield_()
@@ -166,18 +166,18 @@ async def agen_unprotected1():
assert not _core.currently_ki_protected()
# Swap the order of the decorators:
- @async_generator
+ @async_generator # type: ignore[misc] # untyped generator
@_core.enable_ki_protection
- async def agen_protected2():
+ async def agen_protected2() -> None:
assert _core.currently_ki_protected()
try:
await yield_()
finally:
assert _core.currently_ki_protected()
- @async_generator
+ @async_generator # type: ignore[misc] # untyped generator
@_core.disable_ki_protection
- async def agen_unprotected2():
+ async def agen_unprotected2() -> None:
assert not _core.currently_ki_protected()
try:
await yield_()
@@ -190,7 +190,7 @@ async def agen_unprotected2():
await _check_agen(agen_unprotected2)
-async def test_native_agen_protection():
+async def test_native_agen_protection() -> None:
# Native async generators
@_core.enable_ki_protection
async def agen_protected():
@@ -230,20 +230,20 @@ async def _check_agen(agen_fn):
# Test the case where there's no magic local anywhere in the call stack
-def test_ki_disabled_out_of_context():
+def test_ki_disabled_out_of_context() -> None:
assert _core.currently_ki_protected()
-def test_ki_disabled_in_del():
+def test_ki_disabled_in_del() -> None:
def nestedfunction():
return _core.currently_ki_protected()
- def __del__():
+ def __del__() -> None:
assert _core.currently_ki_protected()
assert nestedfunction()
@_core.disable_ki_protection
- def outerfunction():
+ def outerfunction() -> None:
assert not _core.currently_ki_protected()
assert not nestedfunction()
__del__()
@@ -253,15 +253,15 @@ def outerfunction():
assert nestedfunction()
-def test_ki_protection_works():
- async def sleeper(name, record):
+def test_ki_protection_works() -> None:
+ async def sleeper(name: str, record) -> None:
try:
while True:
await _core.checkpoint()
except _core.Cancelled:
record.add(name + " ok")
- async def raiser(name, record):
+ async def raiser(name: str, record):
try:
# os.kill runs signal handlers before returning, so we don't need
# to worry that the handler will be delayed
@@ -286,7 +286,7 @@ async def raiser(name, record):
print("check 1")
record_set: set[str] = set()
- async def check_unprotected_kill():
+ async def check_unprotected_kill() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper, "s1", record_set)
nursery.start_soon(sleeper, "s2", record_set)
@@ -301,7 +301,7 @@ async def check_unprotected_kill():
print("check 2")
record_set = set()
- async def check_protected_kill():
+ async def check_protected_kill() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper, "s1", record_set)
nursery.start_soon(sleeper, "s2", record_set)
@@ -316,10 +316,10 @@ async def check_protected_kill():
# error, then kill)
print("check 3")
- async def check_kill_during_shutdown():
+ async def check_kill_during_shutdown() -> None:
token = _core.current_trio_token()
- def kill_during_shutdown():
+ def kill_during_shutdown() -> None:
assert _core.currently_ki_protected()
try:
token.run_sync_soon(kill_during_shutdown)
@@ -340,7 +340,7 @@ class InstrumentOfDeath(Instrument):
def before_run(self) -> None:
ki_self()
- async def main_1():
+ async def main_1() -> None:
await _core.checkpoint()
with pytest.raises(KeyboardInterrupt):
@@ -350,7 +350,7 @@ async def main_1():
print("check 5")
@_core.enable_ki_protection
- async def main_2():
+ async def main_2() -> None:
assert _core.currently_ki_protected()
ki_self()
with pytest.raises(KeyboardInterrupt):
@@ -362,7 +362,7 @@ async def main_2():
print("check 6")
@_core.enable_ki_protection
- async def main_3():
+ async def main_3() -> None:
assert _core.currently_ki_protected()
ki_self()
await _core.cancel_shielded_checkpoint()
@@ -377,7 +377,7 @@ async def main_3():
print("check 7")
@_core.enable_ki_protection
- async def main_4():
+ async def main_4() -> None:
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
@@ -396,7 +396,7 @@ def abort(_: RaiseCancelT) -> Abort:
print("check 8")
@_core.enable_ki_protection
- async def main_5():
+ async def main_5() -> None:
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
@@ -419,7 +419,7 @@ def abort(raise_cancel: RaiseCancelT) -> Abort:
print("check 9")
@_core.enable_ki_protection
- async def main_6():
+ async def main_6() -> None:
ki_self()
with pytest.raises(KeyboardInterrupt):
@@ -430,7 +430,7 @@ async def main_6():
# restrict_keyboard_interrupt_to_checkpoints=True
record_list = []
- async def main_7():
+ async def main_7() -> None:
# We're not KI protected...
assert not _core.currently_ki_protected()
ki_self()
@@ -454,7 +454,7 @@ async def main_7():
print("check 11")
@_core.enable_ki_protection
- async def main_8():
+ async def main_8() -> None:
assert _core.currently_ki_protected()
with _core.CancelScope() as cancel_scope:
cancel_scope.cancel()
@@ -469,16 +469,16 @@ async def main_8():
_core.run(main_8)
-def test_ki_is_good_neighbor():
+def test_ki_is_good_neighbor() -> None:
# in the unlikely event someone overwrites our signal handler, we leave
# the overwritten one be
try:
orig = signal.getsignal(signal.SIGINT)
- def my_handler(signum, frame): # pragma: no cover
+ def my_handler(signum, frame) -> None: # pragma: no cover
pass
- async def main():
+ async def main() -> None:
signal.signal(signal.SIGINT, my_handler)
_core.run(main)
@@ -490,7 +490,7 @@ async def main():
# Regression test for #461
# don't know if _active not being visible is a problem
-def test_ki_with_broken_threads():
+def test_ki_with_broken_threads() -> None:
thread = threading.main_thread()
# scary!
@@ -502,7 +502,7 @@ def test_ki_with_broken_threads():
del threading._active[thread.ident] # type: ignore[attr-defined]
@_core.enable_ki_protection
- async def inner():
+ async def inner() -> None:
assert signal.getsignal(signal.SIGINT) != signal.default_int_handler
_core.run(inner)
diff --git a/trio/_core/_tests/test_mock_clock.py b/trio/_core/_tests/test_mock_clock.py
index 9c74df3334..4d655cd1a4 100644
--- a/trio/_core/_tests/test_mock_clock.py
+++ b/trio/_core/_tests/test_mock_clock.py
@@ -11,7 +11,7 @@
from .tutil import slow
-def test_mock_clock():
+def test_mock_clock() -> None:
REAL_NOW = 123.0
c = MockClock()
c._real_clock = lambda: REAL_NOW
@@ -55,7 +55,7 @@ def test_mock_clock():
assert c2.current_time() < 10
-async def test_mock_clock_autojump(mock_clock):
+async def test_mock_clock_autojump(mock_clock) -> None:
assert mock_clock.autojump_threshold == inf
mock_clock.autojump_threshold = 0
@@ -95,7 +95,7 @@ async def test_mock_clock_autojump(mock_clock):
await sleep(100000)
-async def test_mock_clock_autojump_interference(mock_clock):
+async def test_mock_clock_autojump_interference(mock_clock) -> None:
mock_clock.autojump_threshold = 0.02
mock_clock2 = MockClock()
@@ -112,7 +112,7 @@ async def test_mock_clock_autojump_interference(mock_clock):
await sleep(100000)
-def test_mock_clock_autojump_preset():
+def test_mock_clock_autojump_preset() -> None:
# Check that we can set the autojump_threshold before the clock is
# actually in use, and it gets picked up
mock_clock = MockClock(autojump_threshold=0.1)
@@ -122,7 +122,7 @@ def test_mock_clock_autojump_preset():
assert time.perf_counter() - real_start < 1
-async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock):
+async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock) -> None:
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with the default cushion=0.
@@ -130,11 +130,11 @@ async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock):
record = []
- async def sleeper():
+ async def sleeper() -> None:
await sleep(100)
record.append("yawn")
- async def waiter():
+ async def waiter() -> None:
await wait_all_tasks_blocked()
record.append("waiter woke")
await sleep(1000)
@@ -148,7 +148,9 @@ async def waiter():
@slow
-async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(mock_clock):
+async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(
+ mock_clock,
+) -> None:
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with a non-zero cushion.
@@ -156,11 +158,11 @@ async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(mock_clo
record = []
- async def sleeper():
+ async def sleeper() -> None:
await sleep(100)
record.append("yawn")
- async def waiter():
+ async def waiter() -> None:
await wait_all_tasks_blocked(1)
record.append("waiter done")
diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py
index 6d9fd2a568..60ac41ae14 100644
--- a/trio/_core/_tests/test_multierror.py
+++ b/trio/_core/_tests/test_multierror.py
@@ -9,6 +9,7 @@
import warnings
from pathlib import Path
from traceback import extract_tb, print_exception
+from typing import List
import pytest
@@ -38,11 +39,11 @@ async def raise_nothashable(code):
raise NotHashableException(code)
-def raiser1():
+def raiser1() -> None:
raiser1_2()
-def raiser1_2():
+def raiser1_2() -> None:
raiser1_3()
@@ -50,7 +51,7 @@ def raiser1_3():
raise ValueError("raiser1_string")
-def raiser2():
+def raiser2() -> None:
raiser2_2()
@@ -73,7 +74,7 @@ def get_tb(raiser):
return get_exc(raiser).__traceback__
-def test_concat_tb():
+def test_concat_tb() -> None:
tb1 = get_tb(raiser1)
tb2 = get_tb(raiser2)
@@ -98,7 +99,7 @@ def test_concat_tb():
assert extract_tb(get_tb(raiser2)) == entries2
-def test_MultiError():
+def test_MultiError() -> None:
exc1 = get_exc(raiser1)
exc2 = get_exc(raiser2)
@@ -109,12 +110,12 @@ def test_MultiError():
assert "ValueError" in repr(m)
with pytest.raises(TypeError):
- MultiError(object())
+ MultiError(object()) # type: ignore[arg-type]
with pytest.raises(TypeError):
- MultiError([KeyError(), ValueError])
+ MultiError([KeyError(), ValueError]) # type: ignore[list-item]
-def test_MultiErrorOfSingleMultiError():
+def test_MultiErrorOfSingleMultiError() -> None:
# For MultiError([MultiError]), ensure there is no bad recursion by the
# constructor where __init__ is called if __new__ returns a bare MultiError.
exceptions = (KeyError(), ValueError())
@@ -124,7 +125,7 @@ def test_MultiErrorOfSingleMultiError():
assert b.exceptions == exceptions
-async def test_MultiErrorNotHashable():
+async def test_MultiErrorNotHashable() -> None:
exc1 = NotHashableException(42)
exc2 = NotHashableException(4242)
exc3 = ValueError()
@@ -137,7 +138,7 @@ async def test_MultiErrorNotHashable():
nursery.start_soon(raise_nothashable, 4242)
-def test_MultiError_filter_NotHashable():
+def test_MultiError_filter_NotHashable() -> None:
excs = MultiError([NotHashableException(42), ValueError()])
def handle_ValueError(exc):
@@ -173,7 +174,7 @@ def make_tree():
return MultiError([m12, exc3])
-def assert_tree_eq(m1, m2):
+def assert_tree_eq(m1, m2) -> None:
if m1 is None or m2 is None:
assert m1 is m2
return
@@ -240,7 +241,7 @@ def simple_filter(exc):
+ extract_tb(orig.exceptions[0].exceptions[1].__traceback__)
)
- def p(exc):
+ def p(exc) -> None:
print_exception(type(exc), exc, exc.__traceback__)
p(orig)
@@ -275,7 +276,7 @@ def filter_all(exc):
def test_MultiError_catch():
# No exception to catch
- def noop(_):
+ def noop(_) -> None:
pass # pragma: no cover
with pytest.warns(TrioDeprecationWarning), MultiError.catch(noop):
@@ -391,7 +392,7 @@ def simple_filter(exc):
gc.garbage.clear()
-def assert_match_in_seq(pattern_list, string):
+def assert_match_in_seq(pattern_list: List[str], string: str) -> None:
offset = 0
print("looking for pattern matches...")
for pattern in pattern_list:
@@ -402,14 +403,14 @@ def assert_match_in_seq(pattern_list, string):
offset = match.end()
-def test_assert_match_in_seq():
+def test_assert_match_in_seq() -> None:
assert_match_in_seq(["a", "b"], "xx a xx b xx")
assert_match_in_seq(["b", "a"], "xx b xx a xx")
with pytest.raises(AssertionError):
assert_match_in_seq(["a", "b"], "xx b xx a xx")
-def test_base_multierror():
+def test_base_multierror() -> None:
"""
Test that MultiError() with at least one base exception will return a MultiError
object.
@@ -419,7 +420,7 @@ def test_base_multierror():
assert type(exc) is MultiError
-def test_non_base_multierror():
+def test_non_base_multierror() -> None:
"""
Test that MultiError() without base exceptions will return a NonBaseMultiError
object.
@@ -462,7 +463,7 @@ def run_script(name: str) -> subprocess.CompletedProcess[bytes]:
not Path("/usr/lib/python3/dist-packages/apport_python_hook.py").exists(),
reason="need Ubuntu with python3-apport installed",
)
-def test_apport_excepthook_monkeypatch_interaction():
+def test_apport_excepthook_monkeypatch_interaction() -> None:
completed = run_script("apport_excepthook.py")
stdout = completed.stdout.decode("utf-8")
diff --git a/trio/_core/_tests/test_parking_lot.py b/trio/_core/_tests/test_parking_lot.py
index 3f03fdbade..2946a2fb3c 100644
--- a/trio/_core/_tests/test_parking_lot.py
+++ b/trio/_core/_tests/test_parking_lot.py
@@ -1,3 +1,7 @@
+from __future__ import annotations
+
+from typing import TypeVar
+
import pytest
from ... import _core
@@ -5,11 +9,13 @@
from .._parking_lot import ParkingLot
from .tutil import check_sequence_matches
+T = TypeVar("T")
+
-async def test_parking_lot_basic():
+async def test_parking_lot_basic() -> None:
record = []
- async def waiter(i, lot):
+ async def waiter(i, lot) -> None:
record.append(f"sleep {i}")
await lot.park()
record.append(f"wake {i}")
@@ -76,7 +82,9 @@ async def waiter(i, lot):
lot.unpark(count=1.5)
-async def cancellable_waiter(name, lot, scopes, record):
+async def cancellable_waiter(
+ name: T, lot, scopes: dict[T, _core.CancelScope], record: list[str]
+) -> None:
with _core.CancelScope() as scope:
scopes[name] = scope
record.append(f"sleep {name}")
@@ -88,9 +96,9 @@ async def cancellable_waiter(name, lot, scopes, record):
record.append(f"wake {name}")
-async def test_parking_lot_cancel():
- record = []
- scopes = {}
+async def test_parking_lot_cancel() -> None:
+ record: list[str] = []
+ scopes: dict[int, _core.CancelScope] = {}
async with _core.open_nursery() as nursery:
lot = ParkingLot()
@@ -114,14 +122,14 @@ async def test_parking_lot_cancel():
)
-async def test_parking_lot_repark():
- record = []
- scopes = {}
+async def test_parking_lot_repark() -> None:
+ record: list[str] = []
+ scopes: dict[int, _core.CancelScope] = {}
lot1 = ParkingLot()
lot2 = ParkingLot()
with pytest.raises(TypeError):
- lot1.repark([])
+ lot1.repark([]) # type: ignore[arg-type]
async with _core.open_nursery() as nursery:
nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record)
@@ -168,9 +176,9 @@ async def test_parking_lot_repark():
]
-async def test_parking_lot_repark_with_count():
- record = []
- scopes = {}
+async def test_parking_lot_repark_with_count() -> None:
+ record: list[str] = []
+ scopes: dict[int, _core.CancelScope] = {}
lot1 = ParkingLot()
lot2 = ParkingLot()
async with _core.open_nursery() as nursery:
diff --git a/trio/_core/_tests/test_run.py b/trio/_core/_tests/test_run.py
index f67f83a4b8..8af166b9af 100644
--- a/trio/_core/_tests/test_run.py
+++ b/trio/_core/_tests/test_run.py
@@ -1878,7 +1878,7 @@ async def fail() -> NoReturn:
async def test_nursery_stop_async_iteration() -> None:
class it:
- def __init__(self, count: int):
+ def __init__(self, count: int) -> None:
self.count = count
self.val = 0
@@ -1891,7 +1891,7 @@ async def __anext__(self) -> int:
return val
class async_zip:
- def __init__(self, *largs: it):
+ def __init__(self, *largs: it) -> None:
self.nexts = [obj.__anext__ for obj in largs]
async def _accumulate(
diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py
index 3cd79ecd8a..4893ac3cae 100644
--- a/trio/_core/_tests/test_thread_cache.py
+++ b/trio/_core/_tests/test_thread_cache.py
@@ -14,13 +14,13 @@
from .tutil import gc_collect_harder, slow
-def test_thread_cache_basics():
+def test_thread_cache_basics() -> None:
q = Queue[Outcome]()
def fn() -> NoReturn:
raise RuntimeError("hi")
- def deliver(outcome):
+ def deliver(outcome) -> None:
q.put(outcome)
start_thread_soon(fn, deliver)
@@ -30,19 +30,19 @@ def deliver(outcome):
outcome.unwrap()
-def test_thread_cache_deref():
+def test_thread_cache_deref() -> None:
res = [False]
class del_me:
- def __call__(self):
+ def __call__(self) -> int:
return 42
- def __del__(self):
+ def __del__(self) -> None:
res[0] = True
- q = Queue()
+ q = Queue[Outcome]()
- def deliver(outcome):
+ def deliver(outcome) -> None:
q.put(outcome)
start_thread_soon(del_me(), deliver)
@@ -54,7 +54,7 @@ def deliver(outcome):
@slow
-def test_spawning_new_thread_from_deliver_reuses_starting_thread():
+def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None:
# We know that no-one else is using the thread cache, so if we keep
# submitting new jobs the instant the previous one is finished, we should
# keep getting the same thread over and over. This tests both that the
@@ -63,7 +63,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread():
# Make sure there are a few threads running, so if we weren't LIFO then we
# could grab the wrong one.
- q = Queue()
+ q = Queue[Outcome]()
COUNT = 5
for _ in range(COUNT):
start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result))
@@ -73,7 +73,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread():
seen_threads = set()
done = threading.Event()
- def deliver(n, _):
+ def deliver(n, _) -> None:
print(n)
seen_threads.add(threading.current_thread())
if n == 0:
@@ -89,13 +89,13 @@ def deliver(n, _):
@slow
-def test_idle_threads_exit(monkeypatch):
+def test_idle_threads_exit(monkeypatch) -> None:
# Temporarily set the idle timeout to something tiny, to speed up the
# test. (But non-zero, so that the worker loop will at least yield the
# CPU.)
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
- q = Queue()
+ q = Queue[threading.Thread]()
start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread()))
seen_thread = q.get()
# Since the idle timeout is 0, after sleeping for 1 second, the thread
@@ -116,7 +116,7 @@ def _join_started_threads():
assert not thread.is_alive()
-def test_race_between_idle_exit_and_job_assignment(monkeypatch):
+def test_race_between_idle_exit_and_job_assignment(monkeypatch) -> None:
# This is a lock where the first few times you try to acquire it with a
# timeout, it waits until the lock is available and then pretends to time
# out. Using this in our thread cache implementation causes the following
@@ -135,11 +135,11 @@ def test_race_between_idle_exit_and_job_assignment(monkeypatch):
# everything proceeds as normal.
class JankyLock:
- def __init__(self):
+ def __init__(self) -> None:
self._lock = threading.Lock()
self._counter = 3
- def acquire(self, timeout=-1):
+ def acquire(self, timeout=-1) -> bool:
got_it = self._lock.acquire(timeout=timeout)
if timeout == -1:
return True
@@ -152,7 +152,7 @@ def acquire(self, timeout=-1):
else:
return False
- def release(self):
+ def release(self) -> None:
self._lock.release()
monkeypatch.setattr(_thread_cache, "Lock", JankyLock)
@@ -169,10 +169,10 @@ def release(self):
tc.start_thread_soon(lambda: None, lambda _: None)
-def test_raise_in_deliver(capfd):
+def test_raise_in_deliver(capfd) -> None:
seen_threads = set()
- def track_threads():
+ def track_threads() -> None:
seen_threads.add(threading.current_thread())
def deliver(_):
diff --git a/trio/_core/_tests/tutil.py b/trio/_core/_tests/tutil.py
index 070af8ed15..1d49d8b262 100644
--- a/trio/_core/_tests/tutil.py
+++ b/trio/_core/_tests/tutil.py
@@ -9,7 +9,7 @@
import warnings
from collections.abc import Generator, Iterable, Sequence
from contextlib import closing, contextmanager
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, TypeVar
import pytest
@@ -18,6 +18,8 @@
slow = pytest.mark.skipif(not RUN_SLOW, reason="use --run-slow to run slow tests")
+T = TypeVar("T")
+
# PyPy 7.2 was released with a bug that just never called the async
# generator 'firstiter' hook at all. This impacts tests of end-of-run
# finalization (nothing gets added to runner.asyncgens) and tests of
@@ -98,9 +100,7 @@ def restore_unraisablehook() -> Generator[None, None, None]:
# template is like:
# [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3]
-def check_sequence_matches(
- seq: Sequence[object], template: Iterable[object | set[object]]
-) -> None:
+def check_sequence_matches(seq: Sequence[T], template: Iterable[T | set[T]]) -> None:
i = 0
for pattern in template:
if not isinstance(pattern, set):
diff --git a/trio/_tests/test_file_io.py b/trio/_tests/test_file_io.py
index 863ebe81b0..da808c5291 100644
--- a/trio/_tests/test_file_io.py
+++ b/trio/_tests/test_file_io.py
@@ -231,7 +231,7 @@ async def test_aclose_cancelled(path) -> None:
async def test_detach_rewraps_asynciobase() -> None:
raw = io.BytesIO()
- buffered = io.BufferedReader(raw)
+ buffered = io.BufferedReader(raw) # type: ignore[arg-type] # ????????????
async_file = trio.wrap_file(buffered)
diff --git a/trio/_tests/test_highlevel_generic.py b/trio/_tests/test_highlevel_generic.py
index 64c5697184..5ccc69151f 100644
--- a/trio/_tests/test_highlevel_generic.py
+++ b/trio/_tests/test_highlevel_generic.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import attr
import pytest
@@ -7,7 +9,7 @@
@attr.s
class RecordSendStream(SendStream):
- record = attr.ib(factory=list)
+ record: list[str | tuple[str, object]] = attr.ib(factory=list)
async def send_all(self, data) -> None:
self.record.append(("send_all", data))
@@ -21,9 +23,9 @@ async def aclose(self) -> None:
@attr.s
class RecordReceiveStream(ReceiveStream):
- record = attr.ib(factory=list)
+ record: list[str | tuple[str, int | None]] = attr.ib(factory=list)
- async def receive_some(self, max_bytes=None) -> None:
+ async def receive_some(self, max_bytes: int | None = None) -> None: # type: ignore[override]
self.record.append(("receive_some", max_bytes))
async def aclose(self) -> None:
@@ -53,7 +55,7 @@ async def test_StapledStream() -> None:
async def fake_send_eof() -> None:
send_stream.record.append("send_eof")
- send_stream.send_eof = fake_send_eof
+ send_stream.send_eof = fake_send_eof # type: ignore[attr-defined]
await stapled.send_eof()
assert send_stream.record == ["send_eof"]
diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py
index beac59f86d..cf3c79f56f 100644
--- a/trio/_tests/test_highlevel_serve_listeners.py
+++ b/trio/_tests/test_highlevel_serve_listeners.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import errno
from functools import partial
@@ -72,7 +74,9 @@ async def do_tests(parent_nursery) -> None:
parent_nursery.cancel_scope.cancel()
async with trio.open_nursery() as nursery:
- l2 = await nursery.start(trio.serve_listeners, handler, listeners)
+ l2: list[MemoryListener] = await nursery.start(
+ trio.serve_listeners, handler, listeners
+ )
assert l2 == listeners
# This is just split into another function because gh-136 isn't
# implemented yet
@@ -92,7 +96,7 @@ async def raise_error():
listener.accept_hook = raise_error
with pytest.raises(type(error)) as excinfo:
- await trio.serve_listeners(None, [listener])
+ await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
assert excinfo.value is error
@@ -107,7 +111,7 @@ async def raise_EMFILE():
# It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900
# = 10 times total
with trio.move_on_after(0.950):
- await trio.serve_listeners(None, [listener])
+ await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
assert len(caplog.records) == 10
for record in caplog.records:
@@ -133,7 +137,7 @@ async def connection_watcher(*, task_status=trio.TASK_STATUS_IGNORED):
with pytest.raises(Done):
async with trio.open_nursery() as nursery:
- handler_nursery = await nursery.start(connection_watcher)
+ handler_nursery: trio.Nursery = await nursery.start(connection_watcher)
await nursery.start(
partial(
trio.serve_listeners,
diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py
index 9bec4ae2f5..afb5f30a6b 100644
--- a/trio/_tests/test_highlevel_ssl_helpers.py
+++ b/trio/_tests/test_highlevel_ssl_helpers.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from functools import partial
import attr
@@ -7,11 +9,13 @@
import trio.testing
from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM
+from .._highlevel_socket import SocketListener
from .._highlevel_ssl_helpers import (
open_ssl_over_tcp_listeners,
open_ssl_over_tcp_stream,
serve_ssl_over_tcp,
)
+from .._ssl import SSLListener
# using noqa because linters don't understand how pytest fixtures work.
from .test_ssl import SERVER_CTX, client_ctx # noqa: F401
@@ -48,11 +52,17 @@ async def test_open_ssl_over_tcp_stream_and_everything_else(
client_ctx, # noqa: F811 # linters doesn't understand fixture
) -> None:
async with trio.open_nursery() as nursery:
- (listener,) = await nursery.start(
+ # TODO: the types are *very* funky here, this seems like an error in some signature
+ # unless this is doing stuff we don't want/expect end users to do
+ res: list[SSLListener] = await nursery.start(
partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1")
)
+ (listener,) = res
async with listener:
- sockaddr = listener.transport_listener.socket.getsockname()
+ # listener.transport_listener is of type Listener[Stream]
+ tp_listener: SocketListener = listener.transport_listener # type: ignore[assignment]
+
+ sockaddr = tp_listener.socket.getsockname()
hostname_resolver = FakeHostnameResolver(sockaddr)
trio.socket.set_custom_hostname_resolver(hostname_resolver)
diff --git a/trio/_tests/test_path.py b/trio/_tests/test_path.py
index bb24764c23..bfef1aaf2c 100644
--- a/trio/_tests/test_path.py
+++ b/trio/_tests/test_path.py
@@ -14,20 +14,20 @@ def path(tmpdir):
return trio.Path(p)
-def method_pair(path, method_name: str):
+def method_pair(path, method_name):
path = pathlib.Path(path)
async_path = trio.Path(path)
return getattr(path, method_name), getattr(async_path, method_name)
-async def test_open_is_async_context_manager(path) -> None:
+async def test_open_is_async_context_manager(path):
async with await path.open("w") as f:
assert isinstance(f, AsyncIOWrapper)
assert f.closed
-async def test_magic() -> None:
+async def test_magic():
path = trio.Path("test")
assert str(path) == "test"
@@ -42,7 +42,7 @@ async def test_magic() -> None:
@pytest.mark.parametrize("cls_a,cls_b", cls_pairs)
-async def test_cmp_magic(cls_a, cls_b) -> None:
+async def test_cmp_magic(cls_a, cls_b):
a, b = cls_a(""), cls_b("")
assert a == b
assert not a != b
@@ -69,7 +69,7 @@ async def test_cmp_magic(cls_a, cls_b) -> None:
@pytest.mark.parametrize("cls_a,cls_b", cls_pairs)
-async def test_div_magic(cls_a, cls_b) -> None:
+async def test_div_magic(cls_a, cls_b):
a, b = cls_a("a"), cls_b("b")
result = a / b
@@ -81,19 +81,19 @@ async def test_div_magic(cls_a, cls_b) -> None:
"cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)]
)
@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"])
-async def test_hash_magic(cls_a, cls_b, path) -> None:
+async def test_hash_magic(cls_a, cls_b, path):
a, b = cls_a(path), cls_b(path)
assert hash(a) == hash(b)
-async def test_forwarded_properties(path) -> None:
+async def test_forwarded_properties(path):
# use `name` as a representative of forwarded properties
assert "name" in dir(path)
assert path.name == "test"
-async def test_async_method_signature(path) -> None:
+async def test_async_method_signature(path):
# use `resolve` as a representative of wrapped methods
assert path.resolve.__name__ == "resolve"
@@ -103,7 +103,7 @@ async def test_async_method_signature(path) -> None:
@pytest.mark.parametrize("method_name", ["is_dir", "is_file"])
-async def test_compare_async_stat_methods(method_name: str) -> None:
+async def test_compare_async_stat_methods(method_name):
method, async_method = method_pair(".", method_name)
result = method()
@@ -112,13 +112,13 @@ async def test_compare_async_stat_methods(method_name: str) -> None:
assert result == async_result
-async def test_invalid_name_not_wrapped(path) -> None:
+async def test_invalid_name_not_wrapped(path):
with pytest.raises(AttributeError):
getattr(path, "invalid_fake_attr")
@pytest.mark.parametrize("method_name", ["absolute", "resolve"])
-async def test_async_methods_rewrap(method_name: str) -> None:
+async def test_async_methods_rewrap(method_name):
method, async_method = method_pair(".", method_name)
result = method()
@@ -128,7 +128,7 @@ async def test_async_methods_rewrap(method_name: str) -> None:
assert str(result) == str(async_result)
-async def test_forward_methods_rewrap(path, tmpdir) -> None:
+async def test_forward_methods_rewrap(path, tmpdir):
with_name = path.with_name("foo")
with_suffix = path.with_suffix(".py")
@@ -138,17 +138,17 @@ async def test_forward_methods_rewrap(path, tmpdir) -> None:
assert with_suffix == tmpdir.join("test.py")
-async def test_forward_properties_rewrap(path) -> None:
+async def test_forward_properties_rewrap(path):
assert isinstance(path.parent, trio.Path)
-async def test_forward_methods_without_rewrap(path, tmpdir) -> None:
+async def test_forward_methods_without_rewrap(path, tmpdir):
path = await path.parent.resolve()
assert path.as_uri().startswith("file:///")
-async def test_repr() -> None:
+async def test_repr():
path = trio.Path(".")
assert repr(path) == "trio.Path('.')"
@@ -164,30 +164,30 @@ class MockWrapper:
_wraps = MockWrapped
-async def test_type_forwards_unsupported() -> None:
+async def test_type_forwards_unsupported():
with pytest.raises(TypeError):
Type.generate_forwards(MockWrapper, {})
-async def test_type_wraps_unsupported() -> None:
+async def test_type_wraps_unsupported():
with pytest.raises(TypeError):
Type.generate_wraps(MockWrapper, {})
-async def test_type_forwards_private() -> None:
+async def test_type_forwards_private():
Type.generate_forwards(MockWrapper, {"unsupported": None})
assert not hasattr(MockWrapper, "_private")
-async def test_type_wraps_private() -> None:
+async def test_type_wraps_private():
Type.generate_wraps(MockWrapper, {"unsupported": None})
assert not hasattr(MockWrapper, "_private")
@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath])
-async def test_path_wraps_path(path, meth) -> None:
+async def test_path_wraps_path(path, meth):
wrapped = await path.absolute()
result = meth(path, wrapped)
if result is None:
@@ -196,17 +196,17 @@ async def test_path_wraps_path(path, meth) -> None:
assert wrapped == result
-async def test_path_nonpath() -> None:
+async def test_path_nonpath():
with pytest.raises(TypeError):
trio.Path(1)
-async def test_open_file_can_open_path(path) -> None:
+async def test_open_file_can_open_path(path):
async with await trio.open_file(path, "w") as f:
assert f.name == os.fspath(path)
-async def test_globmethods(path) -> None:
+async def test_globmethods(path):
# Populate a directory tree
await path.mkdir()
await (path / "foo").mkdir()
@@ -235,7 +235,7 @@ async def test_globmethods(path) -> None:
assert entries == {"_bar.txt", "bar.txt"}
-async def test_iterdir(path) -> None:
+async def test_iterdir(path):
# Populate a directory
await path.mkdir()
await (path / "foo").mkdir()
@@ -249,7 +249,7 @@ async def test_iterdir(path) -> None:
assert entries == {"bar.txt", "foo"}
-async def test_classmethods() -> None:
+async def test_classmethods():
assert isinstance(await trio.Path.home(), trio.Path)
# pathlib.Path has only two classmethods
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index 8af210a0ff..fadddfdaaa 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -21,6 +21,7 @@
import trio
from .. import _core, socket as tsocket
+from .._abc import Stream
from .._core import BrokenResourceError, ClosedResourceError
from .._core._tests.tutil import slow
from .._highlevel_generic import aclose_forcefully
@@ -737,10 +738,20 @@ async def do_wait_send_all_might_not_block() -> None:
async def test_wait_writable_calls_underlying_wait_writable() -> None:
record = []
- class NotAStream:
+ class NotAStream(Stream):
async def wait_send_all_might_not_block(self) -> None:
record.append("ok")
+ # define methods that are abstract in Stream
+ async def aclose(self) -> None:
+ raise AssertionError("Should not get called")
+
+ async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
+ raise AssertionError("Should not get called")
+
+ async def send_all(self, data: bytes | bytearray | memoryview) -> None:
+ raise AssertionError("Should not get called")
+
ctx = ssl.create_default_context()
s = SSLStream(NotAStream(), ctx, server_hostname="x")
await s.wait_send_all_might_not_block()
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index 0f993976e0..f511ff021e 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -153,6 +153,8 @@ async def test_async_with_basics_deprecated(recwarn) -> None:
) as proc:
pass
assert proc.returncode is not None
+ assert proc.stdin is not None
+ assert proc.stdout is not None
with pytest.raises(ClosedResourceError):
await proc.stdin.send_all(b"x")
with pytest.raises(ClosedResourceError):
@@ -418,7 +420,7 @@ async def test_stderr_stdout(background_process) -> None:
async def test_errors() -> None:
with pytest.raises(TypeError) as excinfo:
- await open_process(["ls"], encoding="utf-8")
+ await open_process(["ls"], encoding="utf-8") # type: ignore[call-overload]
assert "unbuffered byte streams" in str(excinfo.value)
assert "the 'encoding' option is not supported" in str(excinfo.value)
@@ -563,7 +565,9 @@ async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch) ->
async def test_run_process_background_fail() -> None:
with pytest.raises(subprocess.CalledProcessError):
async with _core.open_nursery() as nursery:
- proc = await nursery.start(run_process, EXIT_FALSE)
+ proc: subprocess.CompletedProcess[bytes] = await nursery.start(
+ run_process, EXIT_FALSE
+ )
assert proc.returncode == 1
diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py
index 7747740c9f..91c865e085 100644
--- a/trio/_tests/test_sync.py
+++ b/trio/_tests/test_sync.py
@@ -176,7 +176,7 @@ async def test_CapacityLimiter_memleak_548() -> None:
async def test_Semaphore() -> None:
with pytest.raises(TypeError):
- Semaphore(1.0)
+ Semaphore(1.0) # type: ignore[arg-type]
with pytest.raises(ValueError):
Semaphore(-1)
s = Semaphore(1)
@@ -224,7 +224,7 @@ async def do_acquire(s) -> None:
async def test_Semaphore_bounded() -> None:
with pytest.raises(TypeError):
- Semaphore(1, max_value=1.0)
+ Semaphore(1, max_value=1.0) # type: ignore[arg-type]
with pytest.raises(ValueError):
Semaphore(2, max_value=1)
bs = Semaphore(1, max_value=1)
@@ -317,9 +317,9 @@ async def holder() -> None:
async def test_Condition() -> None:
with pytest.raises(TypeError):
- Condition(Semaphore(1))
+ Condition(Semaphore(1)) # type: ignore[arg-type]
with pytest.raises(TypeError):
- Condition(StrictFIFOLock)
+ Condition(StrictFIFOLock) # type: ignore[arg-type]
l = Lock() # noqa
c = Condition(l)
assert not l.locked()
@@ -407,8 +407,8 @@ async def waiter(i) -> None:
class ChannelLock1(AsyncContextManagerMixin):
- def __init__(self, capacity) -> None:
- self.s, self.r = open_memory_channel(capacity)
+ def __init__(self, capacity: int) -> None:
+ self.s, self.r = open_memory_channel[None](capacity)
for _ in range(capacity - 1):
self.s.send_nowait(None)
@@ -424,7 +424,7 @@ def release(self) -> None:
class ChannelLock2(AsyncContextManagerMixin):
def __init__(self) -> None:
- self.s, self.r = open_memory_channel(10)
+ self.s, self.r = open_memory_channel[None](10)
self.s.send_nowait(None)
def acquire_nowait(self) -> None:
@@ -439,7 +439,7 @@ def release(self) -> None:
class ChannelLock3(AsyncContextManagerMixin):
def __init__(self) -> None:
- self.s, self.r = open_memory_channel(0)
+ self.s, self.r = open_memory_channel[None](0)
# self.acquired is true when one task acquires the lock and
# only becomes false when it's released and no tasks are
# waiting to acquire.
diff --git a/trio/_tests/test_testing.py b/trio/_tests/test_testing.py
index e7e0f95535..fd953e47bf 100644
--- a/trio/_tests/test_testing.py
+++ b/trio/_tests/test_testing.py
@@ -261,7 +261,7 @@ async def test__UnboundeByteQueue() -> None:
ubq.get_nowait()
with pytest.raises(TypeError):
- ubq.put("string")
+ ubq.put("string") # type: ignore[arg-type]
ubq.put(b"abc")
with assert_checkpoints():
diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py
index 246b50533f..860bc977c0 100644
--- a/trio/_tests/test_threads.py
+++ b/trio/_tests/test_threads.py
@@ -8,7 +8,7 @@
import time
import weakref
from functools import partial
-from typing import Callable, Optional
+from typing import TYPE_CHECKING, Callable, Optional
import pytest
import sniffio
@@ -31,7 +31,7 @@ async def test_do_in_trio_thread() -> None:
trio_thread = threading.current_thread()
async def check_case(do_in_trio_thread, fn, expected, trio_token=None) -> None:
- record = []
+ record: list[tuple[str, threading.Thread | type[BaseException]]] = []
def threadfn() -> None:
try:
@@ -51,35 +51,35 @@ def threadfn() -> None:
token = _core.current_trio_token()
- def f(record) -> int:
+ def f1(record) -> int:
assert not _core.currently_ki_protected()
record.append(("f", threading.current_thread()))
return 2
- await check_case(from_thread_run_sync, f, ("got", 2), trio_token=token)
+ await check_case(from_thread_run_sync, f1, ("got", 2), trio_token=token)
- def f(record):
+ def f2(record):
assert not _core.currently_ki_protected()
record.append(("f", threading.current_thread()))
raise ValueError
- await check_case(from_thread_run_sync, f, ("error", ValueError), trio_token=token)
+ await check_case(from_thread_run_sync, f2, ("error", ValueError), trio_token=token)
- async def f(record) -> int:
+ async def f3(record) -> int:
assert not _core.currently_ki_protected()
await _core.checkpoint()
record.append(("f", threading.current_thread()))
return 3
- await check_case(from_thread_run, f, ("got", 3), trio_token=token)
+ await check_case(from_thread_run, f3, ("got", 3), trio_token=token)
- async def f(record):
+ async def f4(record):
assert not _core.currently_ki_protected()
await _core.checkpoint()
record.append(("f", threading.current_thread()))
raise KeyError
- await check_case(from_thread_run, f, ("error", KeyError), trio_token=token)
+ await check_case(from_thread_run, f4, ("error", KeyError), trio_token=token)
async def test_do_in_trio_thread_from_trio_thread() -> None:
@@ -300,7 +300,7 @@ def g():
async def test_run_in_worker_thread_cancellation() -> None:
- register = [None]
+ register: list[str | None] = [None]
def f(q) -> None:
# Make the thread block for a controlled amount of time
@@ -315,8 +315,8 @@ async def child(q, cancellable):
finally:
record.append("exit")
- record = []
- q = stdlib_queue.Queue()
+ record: list[str] = []
+ q = stdlib_queue.Queue[None]()
async with _core.open_nursery() as nursery:
nursery.start_soon(child, q, True)
# Give it a chance to get started. (This is important because
@@ -362,8 +362,8 @@ async def child(q, cancellable):
def test_run_in_worker_thread_abandoned(capfd, monkeypatch) -> None:
monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)
- q1 = stdlib_queue.Queue()
- q2 = stdlib_queue.Queue()
+ q1 = stdlib_queue.Queue[None]()
+ q2 = stdlib_queue.Queue[threading.Thread]()
def thread_fn() -> None:
q1.get()
@@ -427,6 +427,11 @@ async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter) ->
# Mutating them in-place is OK though (as long as you use proper
# locking etc.).
class state:
+ if TYPE_CHECKING:
+ ran: int
+ high_water: int
+ running: int
+ parked: int
pass
state.ran = 0
@@ -515,7 +520,9 @@ def release_on_behalf_of(self, borrower) -> None:
record.append("release")
assert borrower == self._borrower
- await to_thread_run_sync(lambda: None, limiter=CustomLimiter())
+ # TODO: should CapacityLimiter have an abc or protocol so users can modify it?
+ # because currently it's `final` so writing code like this is not allowed.
+ await to_thread_run_sync(lambda: None, limiter=CustomLimiter()) # type: ignore[arg-type]
assert record == ["acquire", "release"]
@@ -533,16 +540,16 @@ def release_on_behalf_of(self, borrower):
bs = BadCapacityLimiter()
with pytest.raises(ValueError) as excinfo:
- await to_thread_run_sync(lambda: None, limiter=bs)
+ await to_thread_run_sync(lambda: None, limiter=bs) # type: ignore[arg-type]
assert excinfo.value.__context__ is None
assert record == ["acquire", "release"]
record = []
# If the original function raised an error, then the semaphore error
# chains with it
- d = {}
+ d: dict[str, object] = {}
with pytest.raises(ValueError) as excinfo:
- await to_thread_run_sync(lambda: d["x"], limiter=bs)
+ await to_thread_run_sync(lambda: d["x"], limiter=bs) # type: ignore[arg-type]
assert isinstance(excinfo.value.__context__, KeyError)
assert record == ["acquire", "release"]
@@ -583,7 +590,7 @@ async def async_fn() -> None: # pragma: no cover
pass
with pytest.raises(TypeError, match="expected a sync function"):
- await to_thread_run_sync(async_fn)
+ await to_thread_run_sync(async_fn) # type: ignore[unused-coroutine]
trio_test_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar(
@@ -631,22 +638,22 @@ def g():
async def test_trio_from_thread_run_sync() -> None:
# Test that to_thread_run_sync correctly "hands off" the trio token to
# trio.from_thread.run_sync()
- def thread_fn():
+ def thread_fn_1():
trio_time = from_thread_run_sync(_core.current_time)
return trio_time
- trio_time = await to_thread_run_sync(thread_fn)
+ trio_time = await to_thread_run_sync(thread_fn_1)
assert isinstance(trio_time, float)
# Test correct error when passed async function
async def async_fn() -> None: # pragma: no cover
pass
- def thread_fn() -> None:
- from_thread_run_sync(async_fn)
+ def thread_fn_2() -> None:
+ from_thread_run_sync(async_fn) # type: ignore[unused-coroutine]
with pytest.raises(TypeError, match="expected a sync function"):
- await to_thread_run_sync(thread_fn)
+ await to_thread_run_sync(thread_fn_2)
async def test_trio_from_thread_run() -> None:
@@ -793,7 +800,10 @@ async def async_back_in_main():
def test_run_fn_as_system_task_catched_badly_typed_token() -> None:
with pytest.raises(RuntimeError):
- from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype")
+ from_thread_run_sync(
+ _core.current_time,
+ trio_token="Not TrioTokentype", # type: ignore[arg-type]
+ )
async def test_from_thread_inside_trio_thread() -> None:
@@ -841,4 +851,4 @@ def __bool__(self) -> bool:
raise NotImplementedError
with pytest.raises(NotImplementedError):
- await to_thread_run_sync(int, cancellable=BadBool())
+ await to_thread_run_sync(int, cancellable=BadBool()) # type: ignore[arg-type]
diff --git a/trio/_tests/test_util.py b/trio/_tests/test_util.py
index ef99e8f66f..f074a06096 100644
--- a/trio/_tests/test_util.py
+++ b/trio/_tests/test_util.py
@@ -108,7 +108,7 @@ async def f() -> None: # pragma: no cover
pass
with pytest.raises(TypeError) as excinfo:
- coroutine_or_error(f())
+ coroutine_or_error(f()) # type: ignore[arg-type, unused-coroutine]
assert "expecting an async function" in str(excinfo.value)
import asyncio
@@ -120,35 +120,37 @@ def generator_based_coro(): # pragma: no cover
yield from asyncio.sleep(1)
with pytest.raises(TypeError) as excinfo:
- coroutine_or_error(generator_based_coro())
+ coroutine_or_error(generator_based_coro()) # type: ignore[arg-type, unused-coroutine]
assert "asyncio" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
- coroutine_or_error(create_asyncio_future_in_new_loop())
+ coroutine_or_error(create_asyncio_future_in_new_loop()) # type: ignore[arg-type, unused-coroutine]
assert "asyncio" in str(excinfo.value)
+ # does not raise arg-type error
with pytest.raises(TypeError) as excinfo:
- coroutine_or_error(create_asyncio_future_in_new_loop)
+ coroutine_or_error(create_asyncio_future_in_new_loop) # type: ignore[unused-coroutine]
assert "asyncio" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
- coroutine_or_error(Deferred())
+ coroutine_or_error(Deferred()) # type: ignore[arg-type, unused-coroutine]
assert "twisted" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
- coroutine_or_error(lambda: Deferred())
+ coroutine_or_error(lambda: Deferred()) # type: ignore[arg-type, unused-coroutine, return-value]
assert "twisted" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
- coroutine_or_error(len, [[1, 2, 3]])
+ coroutine_or_error(len, [[1, 2, 3]]) # type: ignore[arg-type, unused-coroutine]
assert "appears to be synchronous" in str(excinfo.value)
async def async_gen(arg): # pragma: no cover
yield
+ # does not give arg-type typing error
with pytest.raises(TypeError) as excinfo:
- coroutine_or_error(async_gen, [0])
+ coroutine_or_error(async_gen, [0]) # type: ignore[unused-coroutine]
msg = "expected an async function but got an async generator"
assert msg in str(excinfo.value)
@@ -165,8 +167,8 @@ def test_func(arg):
assert test_func is test_func[int] is test_func[int, str]
assert test_func(42) == test_func[int](42) == 42
assert test_func.__doc__ == "Look, a docstring!"
- assert test_func.__qualname__ == "test_generic_function..test_func"
- assert test_func.__name__ == "test_func"
+ assert test_func.__qualname__ == "test_generic_function..test_func" # type: ignore[attr-defined]
+ assert test_func.__name__ == "test_func" # type: ignore[attr-defined]
assert test_func.__module__ == __name__
@@ -206,7 +208,7 @@ def __init__(self, a: int, b: float) -> None:
def test_fixup_module_metadata() -> None:
# Ignores modules not in the trio.X tree.
non_trio_module = types.ModuleType("not_trio")
- non_trio_module.some_func = lambda: None
+ non_trio_module.some_func = lambda: None # type: ignore[attr-defined]
non_trio_module.some_func.__name__ = "some_func"
non_trio_module.some_func.__qualname__ = "some_func"
@@ -217,26 +219,26 @@ def test_fixup_module_metadata() -> None:
# Bulild up a fake module to test. Just use lambdas since all we care about is the names.
mod = types.ModuleType("trio._somemodule_impl")
- mod.some_func = lambda: None
+ mod.some_func = lambda: None # type: ignore[attr-defined]
mod.some_func.__name__ = "_something_else"
mod.some_func.__qualname__ = "_something_else"
# No __module__ means it's unchanged.
- mod.not_funclike = types.SimpleNamespace()
+ mod.not_funclike = types.SimpleNamespace() # type: ignore[attr-defined]
mod.not_funclike.__name__ = "not_funclike"
# Check __qualname__ being absent works.
- mod.only_has_name = types.SimpleNamespace()
+ mod.only_has_name = types.SimpleNamespace() # type: ignore[attr-defined]
mod.only_has_name.__module__ = "trio._somemodule_impl"
mod.only_has_name.__name__ = "only_name"
# Underscored names are unchanged.
- mod._private = lambda: None
+ mod._private = lambda: None # type: ignore[attr-defined]
mod._private.__module__ = "trio._somemodule_impl"
mod._private.__name__ = mod._private.__qualname__ = "_private"
# We recurse into classes.
- mod.SomeClass = type(
+ mod.SomeClass = type( # type: ignore[attr-defined]
"SomeClass",
(),
{
@@ -244,7 +246,8 @@ def test_fixup_module_metadata() -> None:
"method": lambda self: None,
},
)
- mod.SomeClass.recursion = mod.SomeClass # Reference loop is fine.
+ # Reference loop is fine.
+ mod.SomeClass.recursion = mod.SomeClass # type: ignore[attr-defined]
fixup_module_metadata("trio.somemodule", vars(mod))
assert mod.some_func.__name__ == "some_func"
@@ -260,9 +263,9 @@ def test_fixup_module_metadata() -> None:
assert mod.only_has_name.__module__ == "trio.somemodule"
assert not hasattr(mod.only_has_name, "__qualname__")
- assert mod.SomeClass.method.__name__ == "method"
- assert mod.SomeClass.method.__module__ == "trio.somemodule"
- assert mod.SomeClass.method.__qualname__ == "SomeClass.method"
+ assert mod.SomeClass.method.__name__ == "method" # type: ignore[attr-defined]
+ assert mod.SomeClass.method.__module__ == "trio.somemodule" # type: ignore[attr-defined]
+ assert mod.SomeClass.method.__qualname__ == "SomeClass.method" # type: ignore[attr-defined]
# Make coverage happy.
non_trio_module.some_func()
mod.some_func()
diff --git a/trio/_tests/test_windows_pipes.py b/trio/_tests/test_windows_pipes.py
index 399d7116bb..5c4bae7d25 100644
--- a/trio/_tests/test_windows_pipes.py
+++ b/trio/_tests/test_windows_pipes.py
@@ -24,14 +24,14 @@ async def make_pipe() -> Tuple[PipeSendStream, PipeReceiveStream]:
return PipeSendStream(w), PipeReceiveStream(r)
-async def test_pipe_typecheck() -> None:
+async def test_pipe_typecheck():
with pytest.raises(TypeError):
PipeSendStream(1.0)
with pytest.raises(TypeError):
PipeReceiveStream(None)
-async def test_pipe_error_on_close() -> None:
+async def test_pipe_error_on_close():
# Make sure we correctly handle a failure from kernel32.CloseHandle
r, w = pipe()
@@ -47,18 +47,18 @@ async def test_pipe_error_on_close() -> None:
await receive_stream.aclose()
-async def test_pipes_combined() -> None:
+async def test_pipes_combined():
write, read = await make_pipe()
count = 2**20
replicas = 3
- async def sender() -> None:
+ async def sender():
async with write:
big = bytearray(count)
for _ in range(replicas):
await write.send_all(big)
- async def reader() -> None:
+ async def reader():
async with read:
await wait_all_tasks_blocked()
total_received = 0
@@ -76,7 +76,7 @@ async def reader() -> None:
n.start_soon(reader)
-async def test_async_with() -> None:
+async def test_async_with():
w, r = await make_pipe()
async with w, r:
pass
@@ -87,11 +87,11 @@ async def test_async_with() -> None:
await r.receive_some(10)
-async def test_close_during_write() -> None:
+async def test_close_during_write():
w, r = await make_pipe()
async with _core.open_nursery() as nursery:
- async def write_forever() -> None:
+ async def write_forever():
with pytest.raises(_core.ClosedResourceError) as excinfo:
while True:
await w.send_all(b"x" * 4096)
@@ -102,7 +102,7 @@ async def write_forever() -> None:
await w.aclose()
-async def test_pipe_fully() -> None:
+async def test_pipe_fully():
# passing make_clogged_pipe tests wait_send_all_might_not_block, and we
# can't implement that on Windows
await check_one_way_stream(make_pipe, None)
From 9dd47b3bba6b679dc6b8ed043906f59e346acb66 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 17 Oct 2023 13:42:54 +0200
Subject: [PATCH 05/35] enable check_untyped_defs
---
pyproject.toml | 1 -
trio/_tests/test_highlevel_serve_listeners.py | 2 +-
trio/_tests/test_ssl.py | 4 +++-
trio/_tests/test_subprocess.py | 11 ++++++++---
trio/_tests/test_testing.py | 2 +-
trio/_tests/test_tracing.py | 15 +++++++++++----
6 files changed, 24 insertions(+), 11 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 96e8b008bf..4ef771fcd7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -112,7 +112,6 @@ module = [
"trio/_tests/test_wait_for_object",
"trio/_tests/tools/test_gen_exports",
]
-check_untyped_defs = false
disallow_any_decorated = false
disallow_any_generics = false
disallow_any_unimported = false
diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py
index cf3c79f56f..70b01042b2 100644
--- a/trio/_tests/test_highlevel_serve_listeners.py
+++ b/trio/_tests/test_highlevel_serve_listeners.py
@@ -13,7 +13,7 @@
@attr.s(hash=False, eq=False)
class MemoryListener(trio.abc.Listener):
closed = attr.ib(default=False)
- accepted_streams = attr.ib(factory=list)
+ accepted_streams: list[trio.abc.Stream] = attr.ib(factory=list)
queued_streams = attr.ib(
factory=(lambda: trio.open_memory_channel[trio.StapledStream](1))
)
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index fadddfdaaa..85d7f43069 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -167,7 +167,9 @@ async def ssl_echo_server(client_ctx, **kwargs):
# The weird in-memory server ... thing.
# Doesn't inherit from Stream because I left out the methods that we don't
# actually need.
-class PyOpenSSLEchoStream:
+# jakkdl: it seems to implement all the abstract methods (now), so I made it inherit
+# from Stream for the sake of typechecking.
+class PyOpenSSLEchoStream(Stream):
def __init__(self, sleeper=None) -> None:
ctx = SSL.Context(SSL.SSLv23_METHOD)
# TLS 1.3 removes renegotiation support. Which is great for them, but
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index f511ff021e..d6bf47ce82 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -8,7 +8,7 @@
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path as SyncPath
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, AsyncIterator
import pytest
@@ -80,11 +80,16 @@ async def open_process_then_kill(*args, **kwargs):
await proc.wait()
+# not entirely sure about this annotation
@asynccontextmanager
-async def run_process_in_nursery(*args, **kwargs):
+async def run_process_in_nursery(
+ *args, **kwargs
+) -> AsyncIterator[subprocess.CompletedProcess[bytes]]:
async with _core.open_nursery() as nursery:
kwargs.setdefault("check", False)
- proc = await nursery.start(partial(run_process, *args, **kwargs))
+ proc: subprocess.CompletedProcess[bytes] = await nursery.start(
+ partial(run_process, *args, **kwargs)
+ )
yield proc
nursery.cancel_scope.cancel()
diff --git a/trio/_tests/test_testing.py b/trio/_tests/test_testing.py
index fd953e47bf..08f896b725 100644
--- a/trio/_tests/test_testing.py
+++ b/trio/_tests/test_testing.py
@@ -237,7 +237,7 @@ async def test__assert_raises():
with pytest.raises(TypeError):
with _assert_raises(RuntimeError):
- "foo" + 1
+ "foo" + 1 # type: ignore[operator]
with _assert_raises(RuntimeError):
raise RuntimeError
diff --git a/trio/_tests/test_tracing.py b/trio/_tests/test_tracing.py
index 0cef2b0f44..29fbb7a475 100644
--- a/trio/_tests/test_tracing.py
+++ b/trio/_tests/test_tracing.py
@@ -1,3 +1,5 @@
+from typing import AsyncGenerator
+
import trio
@@ -14,10 +16,15 @@ async def coro3(event: trio.Event) -> None:
await coro2(event)
-async def coro2_async_gen(event):
- yield await trio.lowlevel.checkpoint()
- yield await coro1(event)
- yield await trio.lowlevel.checkpoint()
+async def coro2_async_gen(event) -> AsyncGenerator[None, None]:
+ # mypy does not like `yield await trio.lowlevel.checkpoint()` - but that
+ # should be equivalent to splitting the statement
+ await trio.lowlevel.checkpoint()
+ yield
+ await coro1(event)
+ yield
+ await trio.lowlevel.checkpoint()
+ yield
async def coro3_async_gen(event: trio.Event) -> None:
From 787b41e2b5ea6453c350a1a06bf8774a72267bef Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 17 Oct 2023 15:16:19 +0200
Subject: [PATCH 06/35] enable disallow_any_generics, fully type
highlevel_serve_listers
---
pyproject.toml | 2 -
trio/_core/_tests/test_guest_mode.py | 4 +-
trio/_tests/test_highlevel_serve_listeners.py | 61 +++++++++++++------
3 files changed, 45 insertions(+), 22 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 4ef771fcd7..da22a6d271 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,7 +98,6 @@ module = [
"trio/_tests/test_file_io",
"trio/_tests/test_highlevel_generic",
"trio/_tests/test_highlevel_open_unix_stream",
-"trio/_tests/test_highlevel_serve_listeners",
"trio/_tests/test_highlevel_ssl_helpers",
"trio/_tests/test_scheduler_determinism",
"trio/_tests/test_ssl",
@@ -113,7 +112,6 @@ module = [
"trio/_tests/tools/test_gen_exports",
]
disallow_any_decorated = false
-disallow_any_generics = false
disallow_any_unimported = false
disallow_incomplete_defs = false
disallow_untyped_defs = false
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index 3789cdefa3..bb3eb5e01a 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -35,7 +35,7 @@ def trivial_guest_run(trio_fn, *, in_host_after_start=None, **start_guest_run_kw
host_thread = threading.current_thread()
- def run_sync_soon_threadsafe(fn: Callable) -> None:
+ def run_sync_soon_threadsafe(fn: Callable[[], object]) -> None:
nonlocal todo
if host_thread is threading.current_thread(): # pragma: no cover
crash = partial(
@@ -44,7 +44,7 @@ def run_sync_soon_threadsafe(fn: Callable) -> None:
todo.put(("run", crash))
todo.put(("run", fn))
- def run_sync_soon_not_threadsafe(fn: Callable) -> None:
+ def run_sync_soon_not_threadsafe(fn: Callable[[], object]) -> None:
nonlocal todo
if host_thread is not threading.current_thread(): # pragma: no cover
crash = partial(
diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py
index 70b01042b2..2e333068ec 100644
--- a/trio/_tests/test_highlevel_serve_listeners.py
+++ b/trio/_tests/test_highlevel_serve_listeners.py
@@ -1,31 +1,49 @@
from __future__ import annotations
import errno
+import sys
from functools import partial
+from typing import Awaitable, Callable, NoReturn
import attr
import pytest
import trio
-from trio.testing import memory_stream_pair, wait_all_tasks_blocked
+from trio import Nursery, StapledStream, TaskStatus
+from trio._channel import MemoryReceiveChannel, MemorySendChannel
+from trio.abc import Stream
+from trio.testing import (
+ MemoryReceiveStream,
+ MemorySendStream,
+ MockClock,
+ memory_stream_pair,
+ wait_all_tasks_blocked,
+)
+
+if sys.version_info < (3, 11):
+ from exceptiongroup import ExceptionGroup
+
+# types are somewhat tentative - I just bruteforced them until I got something that didn't
+# give errors
+TypeThing = StapledStream[MemorySendStream, MemoryReceiveStream]
@attr.s(hash=False, eq=False)
-class MemoryListener(trio.abc.Listener):
- closed = attr.ib(default=False)
+class MemoryListener(trio.abc.Listener[TypeThing]):
+ closed: bool = attr.ib(default=False)
accepted_streams: list[trio.abc.Stream] = attr.ib(factory=list)
- queued_streams = attr.ib(
- factory=(lambda: trio.open_memory_channel[trio.StapledStream](1))
- )
- accept_hook = attr.ib(default=None)
+ queued_streams: tuple[
+ MemorySendChannel[TypeThing], MemoryReceiveChannel[TypeThing]
+ ] = attr.ib(factory=(lambda: trio.open_memory_channel[TypeThing](1)))
+ accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None)
- async def connect(self):
+ async def connect(self) -> StapledStream[MemorySendStream, MemoryReceiveStream]:
assert not self.closed
client, server = memory_stream_pair()
await self.queued_streams[0].send(server)
return client
- async def accept(self):
+ async def accept(self) -> TypeThing:
await trio.lowlevel.checkpoint()
assert not self.closed
if self.accept_hook is not None:
@@ -49,18 +67,20 @@ def close_hook() -> None:
assert trio.current_effective_deadline() == float("-inf")
record.append("closed")
- async def handler(stream) -> None:
+ async def handler(
+ stream: StapledStream[MemorySendStream, MemoryReceiveStream]
+ ) -> None:
await stream.send_all(b"123")
assert await stream.receive_some(10) == b"456"
stream.send_stream.close_hook = close_hook
stream.receive_stream.close_hook = close_hook
- async def client(listener) -> None:
+ async def client(listener: MemoryListener) -> None:
s = await listener.connect()
assert await s.receive_some(10) == b"123"
await s.send_all(b"456")
- async def do_tests(parent_nursery) -> None:
+ async def do_tests(parent_nursery: Nursery) -> None:
async with trio.open_nursery() as nursery:
for listener in listeners:
for _ in range(3):
@@ -90,7 +110,7 @@ async def test_serve_listeners_accept_unrecognized_error() -> None:
for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]:
listener = MemoryListener()
- async def raise_error():
+ async def raise_error() -> NoReturn:
raise error
listener.accept_hook = raise_error
@@ -100,10 +120,12 @@ async def raise_error():
assert excinfo.value is error
-async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog) -> None:
+async def test_serve_listeners_accept_capacity_error(
+ autojump_clock: MockClock, caplog: pytest.LogCaptureFixture
+) -> None:
listener = MemoryListener()
- async def raise_EMFILE():
+ async def raise_EMFILE() -> NoReturn:
raise OSError(errno.EMFILE, "out of file descriptors")
listener.accept_hook = raise_EMFILE
@@ -116,19 +138,22 @@ async def raise_EMFILE():
assert len(caplog.records) == 10
for record in caplog.records:
assert "retrying" in record.msg
+ assert isinstance(record.exc_info, ExceptionGroup)
assert record.exc_info[1].errno == errno.EMFILE
-async def test_serve_listeners_connection_nursery(autojump_clock) -> None:
+async def test_serve_listeners_connection_nursery(autojump_clock: MockClock) -> None:
listener = MemoryListener()
- async def handler(stream) -> None:
+ async def handler(stream: Stream) -> None:
await trio.sleep(1)
class Done(Exception):
pass
- async def connection_watcher(*, task_status=trio.TASK_STATUS_IGNORED):
+ async def connection_watcher(
+ *, task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED
+ ) -> NoReturn:
async with trio.open_nursery() as nursery:
task_status.started(nursery)
await wait_all_tasks_blocked()
From 2edff671d242c372a9505a16baad582bd0b375b5 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 17 Oct 2023 15:39:11 +0200
Subject: [PATCH 07/35] type a bunch of files with few type errors
---
pyproject.toml | 13 -------------
trio/_core/_tests/test_guest_mode.py | 3 ++-
trio/_core/_tests/test_mock_clock.py | 10 ++++++----
.../test_multierror_scripts/simple_excepthook.py | 4 ++--
trio/_core/_tests/test_parking_lot.py | 4 ++--
trio/_core/_tests/test_thread_cache.py | 5 +++--
trio/_tests/test_highlevel_generic.py | 8 +++++---
trio/_tests/test_highlevel_open_unix_stream.py | 6 +++---
trio/_tests/test_scheduler_determinism.py | 10 ++++++++--
trio/_tests/test_ssl.py | 3 ++-
trio/_tests/test_subprocess.py | 7 +++++--
trio/_tests/test_testing.py | 2 +-
trio/_tests/test_threads.py | 5 +++--
trio/_tests/test_timeouts.py | 11 +++++++++--
trio/_tests/test_tracing.py | 2 +-
trio/_tests/test_unix_pipes.py | 5 +++--
trio/_tests/test_wait_for_object.py | 4 ++--
17 files changed, 57 insertions(+), 45 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index da22a6d271..6060954e38 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -84,35 +84,22 @@ disallow_untyped_calls = false
[[tool.mypy.overrides]]
module = [
# tests
-"trio/testing/_fake_net",
"trio/_core/_tests/test_guest_mode",
"trio/_core/_tests/test_ki",
-"trio/_core/_tests/test_mock_clock",
"trio/_core/_tests/test_multierror",
-"trio/_core/_tests/test_multierror_scripts/ipython_custom_exc",
-"trio/_core/_tests/test_multierror_scripts/simple_excepthook",
-"trio/_core/_tests/test_parking_lot",
"trio/_core/_tests/test_thread_cache",
-"trio/_core/_tests/test_unbounded_queue",
"trio/_tests/test_exports",
"trio/_tests/test_file_io",
-"trio/_tests/test_highlevel_generic",
-"trio/_tests/test_highlevel_open_unix_stream",
"trio/_tests/test_highlevel_ssl_helpers",
-"trio/_tests/test_scheduler_determinism",
"trio/_tests/test_ssl",
"trio/_tests/test_subprocess",
"trio/_tests/test_sync",
"trio/_tests/test_testing",
"trio/_tests/test_threads",
-"trio/_tests/test_timeouts",
-"trio/_tests/test_tracing",
"trio/_tests/test_util",
-"trio/_tests/test_wait_for_object",
"trio/_tests/tools/test_gen_exports",
]
disallow_any_decorated = false
-disallow_any_unimported = false
disallow_incomplete_defs = false
disallow_untyped_defs = false
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index bb3eb5e01a..ea0deae81e 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -16,6 +16,7 @@
import pytest
from outcome import Outcome
+from pytest import MonkeyPatch
import trio
import trio.testing
@@ -490,7 +491,7 @@ async def aio_pingpong(from_trio, to_trio):
)
-def test_guest_mode_internal_errors(monkeypatch, recwarn) -> None:
+def test_guest_mode_internal_errors(monkeypatch: MonkeyPatch, recwarn) -> None:
with monkeypatch.context() as m:
async def crash_in_run_loop(in_host) -> None:
diff --git a/trio/_core/_tests/test_mock_clock.py b/trio/_core/_tests/test_mock_clock.py
index 4d655cd1a4..1a0c8b3444 100644
--- a/trio/_core/_tests/test_mock_clock.py
+++ b/trio/_core/_tests/test_mock_clock.py
@@ -55,7 +55,7 @@ def test_mock_clock() -> None:
assert c2.current_time() < 10
-async def test_mock_clock_autojump(mock_clock) -> None:
+async def test_mock_clock_autojump(mock_clock: MockClock) -> None:
assert mock_clock.autojump_threshold == inf
mock_clock.autojump_threshold = 0
@@ -95,7 +95,7 @@ async def test_mock_clock_autojump(mock_clock) -> None:
await sleep(100000)
-async def test_mock_clock_autojump_interference(mock_clock) -> None:
+async def test_mock_clock_autojump_interference(mock_clock: MockClock) -> None:
mock_clock.autojump_threshold = 0.02
mock_clock2 = MockClock()
@@ -122,7 +122,9 @@ def test_mock_clock_autojump_preset() -> None:
assert time.perf_counter() - real_start < 1
-async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock) -> None:
+async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(
+ mock_clock: MockClock,
+) -> None:
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with the default cushion=0.
@@ -149,7 +151,7 @@ async def waiter() -> None:
@slow
async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(
- mock_clock,
+ mock_clock: MockClock,
) -> None:
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with a non-zero cushion.
diff --git a/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py
index 65371107bc..236d34e9ba 100644
--- a/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py
+++ b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py
@@ -3,14 +3,14 @@
from trio._core._multierror import MultiError # Bypass deprecation warnings
-def exc1_fn():
+def exc1_fn() -> Exception:
try:
raise ValueError
except Exception as exc:
return exc
-def exc2_fn():
+def exc2_fn() -> Exception:
try:
raise KeyError
except Exception as exc:
diff --git a/trio/_core/_tests/test_parking_lot.py b/trio/_core/_tests/test_parking_lot.py
index 2946a2fb3c..74a4704bef 100644
--- a/trio/_core/_tests/test_parking_lot.py
+++ b/trio/_core/_tests/test_parking_lot.py
@@ -15,7 +15,7 @@
async def test_parking_lot_basic() -> None:
record = []
- async def waiter(i, lot) -> None:
+ async def waiter(i: int, lot: ParkingLot) -> None:
record.append(f"sleep {i}")
await lot.park()
record.append(f"wake {i}")
@@ -83,7 +83,7 @@ async def waiter(i, lot) -> None:
async def cancellable_waiter(
- name: T, lot, scopes: dict[T, _core.CancelScope], record: list[str]
+ name: T, lot: ParkingLot, scopes: dict[T, _core.CancelScope], record: list[str]
) -> None:
with _core.CancelScope() as scope:
scopes[name] = scope
diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py
index 4893ac3cae..0241917b91 100644
--- a/trio/_core/_tests/test_thread_cache.py
+++ b/trio/_core/_tests/test_thread_cache.py
@@ -8,6 +8,7 @@
import pytest
from outcome import Outcome
+from pytest import MonkeyPatch
from .. import _thread_cache
from .._thread_cache import ThreadCache, start_thread_soon
@@ -89,7 +90,7 @@ def deliver(n, _) -> None:
@slow
-def test_idle_threads_exit(monkeypatch) -> None:
+def test_idle_threads_exit(monkeypatch: MonkeyPatch) -> None:
# Temporarily set the idle timeout to something tiny, to speed up the
# test. (But non-zero, so that the worker loop will at least yield the
# CPU.)
@@ -116,7 +117,7 @@ def _join_started_threads():
assert not thread.is_alive()
-def test_race_between_idle_exit_and_job_assignment(monkeypatch) -> None:
+def test_race_between_idle_exit_and_job_assignment(monkeypatch: MonkeyPatch) -> None:
# This is a lock where the first few times you try to acquire it with a
# timeout, it waits until the lock is available and then pretends to time
# out. Using this in our thread cache implementation causes the following
diff --git a/trio/_tests/test_highlevel_generic.py b/trio/_tests/test_highlevel_generic.py
index 5ccc69151f..4b2008c08c 100644
--- a/trio/_tests/test_highlevel_generic.py
+++ b/trio/_tests/test_highlevel_generic.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+from typing import NoReturn
+
import attr
import pytest
@@ -11,7 +13,7 @@
class RecordSendStream(SendStream):
record: list[str | tuple[str, object]] = attr.ib(factory=list)
- async def send_all(self, data) -> None:
+ async def send_all(self, data: object) -> None:
self.record.append(("send_all", data))
async def wait_send_all_might_not_block(self) -> None:
@@ -76,12 +78,12 @@ async def test_StapledStream_with_erroring_close() -> None:
# Make sure that if one of the aclose methods errors out, then the other
# one still gets called.
class BrokenSendStream(RecordSendStream):
- async def aclose(self):
+ async def aclose(self) -> NoReturn:
await super().aclose()
raise ValueError
class BrokenReceiveStream(RecordReceiveStream):
- async def aclose(self):
+ async def aclose(self) -> NoReturn:
await super().aclose()
raise ValueError
diff --git a/trio/_tests/test_highlevel_open_unix_stream.py b/trio/_tests/test_highlevel_open_unix_stream.py
index cf32b8c0fc..045820fccf 100644
--- a/trio/_tests/test_highlevel_open_unix_stream.py
+++ b/trio/_tests/test_highlevel_open_unix_stream.py
@@ -11,7 +11,7 @@
pytestmark = pytest.mark.skip("Needs unix socket support")
-def test_close_on_error():
+def test_close_on_error() -> None:
class CloseMe:
closed = False
@@ -29,9 +29,9 @@ def close(self) -> None:
@pytest.mark.parametrize("filename", [4, 4.5])
-async def test_open_with_bad_filename_type(filename) -> None:
+async def test_open_with_bad_filename_type(filename: float) -> None:
with pytest.raises(TypeError):
- await open_unix_socket(filename)
+ await open_unix_socket(filename) # type: ignore[arg-type]
async def test_open_bad_socket() -> None:
diff --git a/trio/_tests/test_scheduler_determinism.py b/trio/_tests/test_scheduler_determinism.py
index 4c0da698ce..7e0a8e98de 100644
--- a/trio/_tests/test_scheduler_determinism.py
+++ b/trio/_tests/test_scheduler_determinism.py
@@ -1,7 +1,11 @@
+from __future__ import annotations
+
+from pytest import MonkeyPatch
+
import trio
-async def scheduler_trace():
+async def scheduler_trace() -> tuple[tuple[str, int], ...]:
"""Returns a scheduler-dependent value we can use to check determinism."""
trace = []
@@ -25,7 +29,9 @@ def test_the_trio_scheduler_is_not_deterministic() -> None:
assert len(set(traces)) == len(traces)
-def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch) -> None:
+def test_the_trio_scheduler_is_deterministic_if_seeded(
+ monkeypatch: MonkeyPatch,
+) -> None:
monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True)
traces = []
for _ in range(10):
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index 85d7f43069..53943cabce 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -10,6 +10,7 @@
import pytest
+from trio._core import MockClock
from trio._tests.pytest_plugin import skip_if_optional_else_raise
try:
@@ -579,7 +580,7 @@ async def test_renegotiation_simple(client_ctx) -> None:
@slow
-async def test_renegotiation_randomized(mock_clock, client_ctx) -> None:
+async def test_renegotiation_randomized(mock_clock: MockClock, client_ctx) -> None:
# The only blocking things in this function are our random sleeps, so 0 is
# a good threshold.
mock_clock.autojump_threshold = 0
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index d6bf47ce82..3ebf2589bc 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -11,6 +11,7 @@
from typing import TYPE_CHECKING, AsyncIterator
import pytest
+from pytest import MonkeyPatch
from .. import (
ClosedResourceError,
@@ -538,7 +539,7 @@ async def custom_deliver_cancel(proc) -> None:
assert custom_deliver_cancel_called
-async def test_warn_on_failed_cancel_terminate(monkeypatch) -> None:
+async def test_warn_on_failed_cancel_terminate(monkeypatch: MonkeyPatch) -> None:
original_terminate = Process.terminate
def broken_terminate(self):
@@ -555,7 +556,9 @@ def broken_terminate(self):
@pytest.mark.skipif(not posix, reason="posix only")
-async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch) -> None:
+async def test_warn_on_cancel_SIGKILL_escalation(
+ autojump_clock, monkeypatch: MonkeyPatch
+) -> None:
monkeypatch.setattr(Process, "terminate", lambda *args: None)
with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"):
diff --git a/trio/_tests/test_testing.py b/trio/_tests/test_testing.py
index 08f896b725..7461de058c 100644
--- a/trio/_tests/test_testing.py
+++ b/trio/_tests/test_testing.py
@@ -45,7 +45,7 @@ async def cancelled_while_waiting() -> None:
assert record == ["ok"]
-async def test_wait_all_tasks_blocked_with_timeouts(mock_clock) -> None:
+async def test_wait_all_tasks_blocked_with_timeouts(mock_clock: MockClock) -> None:
record = []
async def timeout_task() -> None:
diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py
index 860bc977c0..0fe5d8dc48 100644
--- a/trio/_tests/test_threads.py
+++ b/trio/_tests/test_threads.py
@@ -12,6 +12,7 @@
import pytest
import sniffio
+from pytest import MonkeyPatch
from trio._core import TrioToken, current_trio_token
@@ -359,7 +360,7 @@ async def child(q, cancellable):
# Make sure that if trio.run exits, and then the thread finishes, then that's
# handled gracefully. (Requires that the thread result machinery be prepared
# for call_soon to raise RunFinishedError.)
-def test_run_in_worker_thread_abandoned(capfd, monkeypatch) -> None:
+def test_run_in_worker_thread_abandoned(capfd, monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)
q1 = stdlib_queue.Queue[None]()
@@ -554,7 +555,7 @@ def release_on_behalf_of(self, borrower):
assert record == ["acquire", "release"]
-async def test_run_in_worker_thread_fail_to_spawn(monkeypatch) -> None:
+async def test_run_in_worker_thread_fail_to_spawn(monkeypatch: MonkeyPatch) -> None:
# Test the unlikely but possible case where trying to spawn a thread fails
def bad_start(self, *args):
raise RuntimeError("the engines canna take it captain")
diff --git a/trio/_tests/test_timeouts.py b/trio/_tests/test_timeouts.py
index 1491067c05..918e763faa 100644
--- a/trio/_tests/test_timeouts.py
+++ b/trio/_tests/test_timeouts.py
@@ -1,4 +1,5 @@
import time
+from typing import Awaitable, Callable, TypeVar
import outcome
import pytest
@@ -8,8 +9,12 @@
from .._timeouts import *
from ..testing import assert_checkpoints
+T = TypeVar("T")
-async def check_takes_about(f, expected_dur):
+
+async def check_takes_about(
+ f: Callable[[], Awaitable[T]], expected_dur: float
+) -> Awaitable[T]:
start = time.perf_counter()
result = await outcome.acapture(f)
dur = time.perf_counter() - start
@@ -34,7 +39,9 @@ async def check_takes_about(f, expected_dur):
# value above is exactly 128 ULPs below 1.0, which would make sense if it
# started as a 1 ULP error at a different dynamic range.)
assert (1 - 1e-8) <= (dur / expected_dur) < 1.5
- return result.unwrap()
+
+ # outcome is not typed
+ return result.unwrap() # type: ignore[no-any-return]
# How long to (attempt to) sleep for when testing. Smaller numbers make the
diff --git a/trio/_tests/test_tracing.py b/trio/_tests/test_tracing.py
index 29fbb7a475..405831876c 100644
--- a/trio/_tests/test_tracing.py
+++ b/trio/_tests/test_tracing.py
@@ -16,7 +16,7 @@ async def coro3(event: trio.Event) -> None:
await coro2(event)
-async def coro2_async_gen(event) -> AsyncGenerator[None, None]:
+async def coro2_async_gen(event: trio.Event) -> AsyncGenerator[None, None]:
# mypy does not like `yield await trio.lowlevel.checkpoint()` - but that
# should be equivalent to splitting the statement
await trio.lowlevel.checkpoint()
diff --git a/trio/_tests/test_unix_pipes.py b/trio/_tests/test_unix_pipes.py
index f29ee241c6..5705a75faa 100644
--- a/trio/_tests/test_unix_pipes.py
+++ b/trio/_tests/test_unix_pipes.py
@@ -7,6 +7,7 @@
from typing import TYPE_CHECKING
import pytest
+from pytest import MonkeyPatch
from .. import _core
from .._core._tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken
@@ -182,7 +183,7 @@ async def expect_eof() -> None:
os.close(w2_fd)
-async def test_close_at_bad_time_for_receive_some(monkeypatch) -> None:
+async def test_close_at_bad_time_for_receive_some(monkeypatch: MonkeyPatch) -> None:
# We used to have race conditions where if one task was using the pipe,
# and another closed it at *just* the wrong moment, it would give an
# unexpected error instead of ClosedResourceError:
@@ -210,7 +211,7 @@ async def patched_wait_readable(*args, **kwargs) -> None:
await s.send_all(b"x")
-async def test_close_at_bad_time_for_send_all(monkeypatch) -> None:
+async def test_close_at_bad_time_for_send_all(monkeypatch: MonkeyPatch) -> None:
# We used to have race conditions where if one task was using the pipe,
# and another closed it at *just* the wrong moment, it would give an
# unexpected error instead of ClosedResourceError:
diff --git a/trio/_tests/test_wait_for_object.py b/trio/_tests/test_wait_for_object.py
index 53e771b7ed..44790497ed 100644
--- a/trio/_tests/test_wait_for_object.py
+++ b/trio/_tests/test_wait_for_object.py
@@ -12,7 +12,7 @@
from .._core._tests.tutil import slow
if on_windows:
- from .._core._windows_cffi import ffi, kernel32
+ from .._core._windows_cffi import Handle, ffi, kernel32
from .._wait_for_object import WaitForMultipleObjects_sync, WaitForSingleObject
@@ -168,7 +168,7 @@ async def test_WaitForSingleObject_slow() -> None:
# the timeout with a certain margin.
TIMEOUT = 0.3
- async def signal_soon_async(handle) -> None:
+ async def signal_soon_async(handle: Handle) -> None:
await _timeouts.sleep(TIMEOUT)
kernel32.SetEvent(handle)
From 718baa0fd69c89d5251707f1dfc06c99b0a48c3f Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 18 Oct 2023 14:53:55 +0200
Subject: [PATCH 08/35] more typing
---
pyproject.toml | 8 -
trio/_core/_tests/test_guest_mode.py | 92 ++++++----
trio/_core/_tests/test_thread_cache.py | 16 +-
trio/_ssl.py | 11 +-
trio/_tests/test_exports.py | 22 ++-
trio/_tests/test_highlevel_serve_listeners.py | 7 +-
trio/_tests/test_highlevel_ssl_helpers.py | 32 +++-
trio/_tests/test_ssl.py | 168 +++++++++++-------
trio/_tests/test_subprocess.py | 6 +-
trio/_tests/test_sync.py | 35 +++-
trio/_tests/test_testing.py | 51 +++---
trio/_tests/test_threads.py | 4 +-
trio/_tests/test_util.py | 15 +-
trio/_tests/test_wait_for_object.py | 2 +-
trio/_tests/tools/test_gen_exports.py | 9 +-
15 files changed, 297 insertions(+), 181 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 6060954e38..831d577e65 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,17 +87,9 @@ module = [
"trio/_core/_tests/test_guest_mode",
"trio/_core/_tests/test_ki",
"trio/_core/_tests/test_multierror",
-"trio/_core/_tests/test_thread_cache",
-"trio/_tests/test_exports",
"trio/_tests/test_file_io",
-"trio/_tests/test_highlevel_ssl_helpers",
-"trio/_tests/test_ssl",
"trio/_tests/test_subprocess",
-"trio/_tests/test_sync",
-"trio/_tests/test_testing",
"trio/_tests/test_threads",
-"trio/_tests/test_util",
-"trio/_tests/tools/test_gen_exports",
]
disallow_any_decorated = false
disallow_incomplete_defs = false
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index ea0deae81e..286e8657a1 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -12,18 +12,23 @@
import warnings
from functools import partial
from math import inf
-from typing import Callable
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, NoReturn, TypeVar
import pytest
from outcome import Outcome
-from pytest import MonkeyPatch
+from pytest import MonkeyPatch, WarningsRecorder
+from typing_extensions import TypeAlias
import trio
import trio.testing
+from trio._channel import MemorySendChannel
from ..._util import signal_raise
from .tutil import buggy_pypy_asyncgens, gc_collect_harder, restore_unraisablehook
+T = TypeVar("T")
+InHost: TypeAlias = Callable[[object], None]
+
# The simplest possible "host" loop.
# Nice features:
@@ -31,7 +36,12 @@
# our main
# - final result is returned
# - any unhandled exceptions cause an immediate crash
-def trivial_guest_run(trio_fn, *, in_host_after_start=None, **start_guest_run_kwargs):
+def trivial_guest_run(
+ trio_fn: Callable[..., Awaitable[T]],
+ *,
+ in_host_after_start: Callable[[], None] | None = None,
+ **start_guest_run_kwargs: Any,
+) -> T:
todo: queue.Queue[tuple[str, Outcome]] = queue.Queue()
host_thread = threading.current_thread()
@@ -75,7 +85,7 @@ def done_callback(outcome: Outcome) -> None:
if op == "run":
obj()
elif op == "unwrap":
- return obj.unwrap()
+ return obj.unwrap() # type: ignore[no-any-return]
else: # pragma: no cover
assert False
finally:
@@ -86,13 +96,13 @@ def done_callback(outcome: Outcome) -> None:
def test_guest_trivial() -> None:
- async def trio_return(in_host) -> str:
+ async def trio_return(in_host: InHost) -> str:
await trio.sleep(0)
return "ok"
assert trivial_guest_run(trio_return) == "ok"
- async def trio_fail(in_host):
+ async def trio_fail(in_host: InHost) -> NoReturn:
raise KeyError("whoopsiedaisy")
with pytest.raises(KeyError, match="whoopsiedaisy"):
@@ -100,7 +110,7 @@ async def trio_fail(in_host):
def test_guest_can_do_io() -> None:
- async def trio_main(in_host) -> None:
+ async def trio_main(in_host: InHost) -> None:
record = []
a, b = trio.socket.socketpair()
with a, b:
@@ -123,7 +133,7 @@ def test_guest_is_initialized_when_start_returns() -> None:
trio_token = None
record = []
- async def trio_main(in_host) -> str:
+ async def trio_main(in_host: InHost) -> str:
record.append("main task ran")
await trio.sleep(0)
assert trio.lowlevel.current_trio_token() is trio_token
@@ -151,7 +161,7 @@ async def early_task() -> None:
with pytest.raises(trio.TrioInternalError):
class BadClock:
- def start_clock(self):
+ def start_clock(self) -> NoReturn:
raise ValueError("whoops")
def after_start_never_runs() -> None: # pragma: no cover
@@ -163,7 +173,7 @@ def after_start_never_runs() -> None: # pragma: no cover
def test_host_can_directly_wake_trio_task() -> None:
- async def trio_main(in_host) -> str:
+ async def trio_main(in_host: InHost) -> str:
ev = trio.Event()
in_host(ev.set)
await ev.wait()
@@ -173,10 +183,10 @@ async def trio_main(in_host) -> str:
def test_host_altering_deadlines_wakes_trio_up() -> None:
- def set_deadline(cscope, new_deadline) -> None:
+ def set_deadline(cscope: trio.CancelScope, new_deadline: float) -> None:
cscope.deadline = new_deadline
- async def trio_main(in_host) -> str:
+ async def trio_main(in_host: InHost) -> str:
with trio.CancelScope() as cscope:
in_host(lambda: set_deadline(cscope, -inf))
await trio.sleep_forever()
@@ -198,7 +208,7 @@ async def trio_main(in_host) -> str:
def test_guest_mode_sniffio_integration() -> None:
from sniffio import current_async_library, thread_local as sniffio_library
- async def trio_main(in_host) -> str:
+ async def trio_main(in_host: InHost) -> str:
async def synchronize() -> None:
"""Wait for all in_host() calls issued so far to complete."""
evt = trio.Event()
@@ -227,7 +237,7 @@ async def synchronize() -> None:
def test_warn_set_wakeup_fd_overwrite() -> None:
assert signal.set_wakeup_fd(-1) == -1
- async def trio_main(in_host) -> str:
+ async def trio_main(in_host: InHost) -> str:
return "ok"
a, b = socket.socketpair()
@@ -269,7 +279,7 @@ async def trio_main(in_host) -> str:
signal.set_wakeup_fd(a.fileno())
try:
- async def trio_check_wakeup_fd_unaltered(in_host) -> str:
+ async def trio_check_wakeup_fd_unaltered(in_host: InHost) -> str:
fd = signal.set_wakeup_fd(-1)
assert fd == a.fileno()
signal.set_wakeup_fd(fd)
@@ -295,12 +305,12 @@ def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked() -> None:
# events is Truth-y
# ...and confirm that in this case, wait_all_tasks_blocked does not get
# triggered.
- def set_deadline(cscope, new_deadline) -> None:
+ def set_deadline(cscope: trio.CancelScope, new_deadline: float) -> None:
print(f"setting deadline {new_deadline}")
cscope.deadline = new_deadline
- async def trio_main(in_host) -> str:
- async def sit_in_wait_all_tasks_blocked(watb_cscope) -> None:
+ async def trio_main(in_host: InHost) -> str:
+ async def sit_in_wait_all_tasks_blocked(watb_cscope: trio.CancelScope) -> None:
with watb_cscope:
# Overall point of this test is that this
# wait_all_tasks_blocked should *not* return normally, but
@@ -309,7 +319,7 @@ async def sit_in_wait_all_tasks_blocked(watb_cscope) -> None:
assert False # pragma: no cover
assert watb_cscope.cancelled_caught
- async def get_woken_by_host_deadline(watb_cscope) -> None:
+ async def get_woken_by_host_deadline(watb_cscope: trio.CancelScope) -> None:
with trio.CancelScope() as cscope:
print("scheduling stuff to happen")
@@ -367,7 +377,7 @@ def test_guest_warns_if_abandoned() -> None:
# put it into a function, so that we're sure all the local state,
# traceback frames, etc. are garbage once it returns.
def do_abandoned_guest_run() -> None:
- async def abandoned_main(in_host) -> None:
+ async def abandoned_main(in_host: InHost) -> None:
in_host(lambda: 1 / 0)
while True:
await trio.sleep(0)
@@ -402,13 +412,18 @@ async def abandoned_main(in_host) -> None:
trio.current_time()
-def aiotrio_run(trio_fn, *, pass_not_threadsafe: bool = True, **start_guest_run_kwargs):
+def aiotrio_run(
+ trio_fn: Callable[..., Awaitable[T]],
+ *,
+ pass_not_threadsafe: bool = True,
+ **start_guest_run_kwargs: Any,
+) -> T:
loop = asyncio.new_event_loop()
- async def aio_main():
+ async def aio_main() -> T:
trio_done_fut = loop.create_future()
- def trio_done_callback(main_outcome) -> None:
+ def trio_done_callback(main_outcome: Outcome) -> None:
print(f"trio_fn finished: {main_outcome!r}")
trio_done_fut.set_result(main_outcome)
@@ -422,7 +437,7 @@ def trio_done_callback(main_outcome) -> None:
**start_guest_run_kwargs,
)
- return (await trio_done_fut).unwrap()
+ return (await trio_done_fut).unwrap() # type: ignore[no-any-return]
try:
return loop.run_until_complete(aio_main())
@@ -454,7 +469,9 @@ async def trio_main() -> str:
raise AssertionError("should never be reached")
- async def aio_pingpong(from_trio, to_trio):
+ async def aio_pingpong(
+ from_trio: asyncio.Queue[int], to_trio: MemorySendChannel[int]
+ ) -> None:
print("aio_pingpong!")
try:
@@ -491,10 +508,12 @@ async def aio_pingpong(from_trio, to_trio):
)
-def test_guest_mode_internal_errors(monkeypatch: MonkeyPatch, recwarn) -> None:
+def test_guest_mode_internal_errors(
+ monkeypatch: MonkeyPatch, recwarn: WarningsRecorder
+) -> None:
with monkeypatch.context() as m:
- async def crash_in_run_loop(in_host) -> None:
+ async def crash_in_run_loop(in_host: InHost) -> None:
m.setattr("trio._core._run.GLOBAL_RUN_CONTEXT.runner.runq", "HI")
await trio.sleep(1)
@@ -503,7 +522,7 @@ async def crash_in_run_loop(in_host) -> None:
with monkeypatch.context() as m:
- async def crash_in_io(in_host) -> None:
+ async def crash_in_io(in_host: InHost) -> None:
m.setattr("trio._core._run.TheIOManager.get_events", None)
await trio.sleep(0)
@@ -512,11 +531,11 @@ async def crash_in_io(in_host) -> None:
with monkeypatch.context() as m:
- async def crash_in_worker_thread_io(in_host) -> None:
+ async def crash_in_worker_thread_io(in_host: InHost) -> None:
t = threading.current_thread()
old_get_events = trio._core._run.TheIOManager.get_events
- def bad_get_events(*args):
+ def bad_get_events(*args: Any) -> object:
if threading.current_thread() is not t:
raise ValueError("oh no!")
else:
@@ -536,7 +555,7 @@ def test_guest_mode_ki() -> None:
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
# Check SIGINT in Trio func and in host func
- async def trio_main(in_host) -> None:
+ async def trio_main(in_host: InHost) -> None:
with pytest.raises(KeyboardInterrupt):
signal_raise(signal.SIGINT)
@@ -553,7 +572,7 @@ async def trio_main(in_host) -> None:
# Also check chaining in the case where KI is injected after main exits
final_exc = KeyError("whoa")
- async def trio_main_raising(in_host):
+ async def trio_main_raising(in_host: InHost) -> NoReturn:
in_host(partial(signal_raise, signal.SIGINT))
raise final_exc
@@ -573,7 +592,7 @@ def test_guest_mode_autojump_clock_threshold_changing() -> None:
DURATION = 120
- async def trio_main(in_host) -> None:
+ async def trio_main(in_host: InHost) -> None:
assert trio.current_time() == 0
in_host(lambda: setattr(clock, "autojump_threshold", 0))
await trio.sleep(DURATION)
@@ -594,7 +613,7 @@ def test_guest_mode_asyncgens() -> None:
record = set()
- async def agen(label: str):
+ async def agen(label: str): # TODO: some asyncgenerator thing
assert sniffio.current_async_library() == label
try:
yield 1
@@ -623,6 +642,9 @@ async def trio_main() -> None:
# Ensure we don't pollute the thread-level context if run under
# an asyncio without contextvars support (3.6)
context = contextvars.copy_context()
- context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True)
+ if TYPE_CHECKING:
+ aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True)
+ # this type error is a bug in typeshed or mypy, as it's equivalent to the above line
+ context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True) # type: ignore[arg-type]
assert record == {("asyncio", "asyncio"), ("trio", "trio")}
diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py
index 0241917b91..2b8d913948 100644
--- a/trio/_core/_tests/test_thread_cache.py
+++ b/trio/_core/_tests/test_thread_cache.py
@@ -4,7 +4,7 @@
import time
from contextlib import contextmanager
from queue import Queue
-from typing import NoReturn
+from typing import Iterator, NoReturn
import pytest
from outcome import Outcome
@@ -21,7 +21,7 @@ def test_thread_cache_basics() -> None:
def fn() -> NoReturn:
raise RuntimeError("hi")
- def deliver(outcome) -> None:
+ def deliver(outcome: Outcome) -> None:
q.put(outcome)
start_thread_soon(fn, deliver)
@@ -43,7 +43,7 @@ def __del__(self) -> None:
q = Queue[Outcome]()
- def deliver(outcome) -> None:
+ def deliver(outcome: Outcome) -> None:
q.put(outcome)
start_thread_soon(del_me(), deliver)
@@ -74,7 +74,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None:
seen_threads = set()
done = threading.Event()
- def deliver(n, _) -> None:
+ def deliver(n: int, _: object) -> None:
print(n)
seen_threads.add(threading.current_thread())
if n == 0:
@@ -106,7 +106,7 @@ def test_idle_threads_exit(monkeypatch: MonkeyPatch) -> None:
@contextmanager
-def _join_started_threads():
+def _join_started_threads() -> Iterator[None]:
before = frozenset(threading.enumerate())
try:
yield
@@ -140,7 +140,7 @@ def __init__(self) -> None:
self._lock = threading.Lock()
self._counter = 3
- def acquire(self, timeout=-1) -> bool:
+ def acquire(self, timeout: int = -1) -> bool:
got_it = self._lock.acquire(timeout=timeout)
if timeout == -1:
return True
@@ -170,13 +170,13 @@ def release(self) -> None:
tc.start_thread_soon(lambda: None, lambda _: None)
-def test_raise_in_deliver(capfd) -> None:
+def test_raise_in_deliver(capfd: pytest.CaptureFixture[str]) -> None:
seen_threads = set()
def track_threads() -> None:
seen_threads.add(threading.current_thread())
- def deliver(_):
+ def deliver(_: object) -> NoReturn:
done.set()
raise RuntimeError("don't do this")
diff --git a/trio/_ssl.py b/trio/_ssl.py
index 21c4deabb3..073d7e3812 100644
--- a/trio/_ssl.py
+++ b/trio/_ssl.py
@@ -4,7 +4,7 @@
import ssl as _stdlib_ssl
from collections.abc import Awaitable, Callable
from enum import Enum as _Enum
-from typing import Any, ClassVar, Final as TFinal, TypeVar
+from typing import Any, ClassVar, Final as TFinal, Generic, TypeVar
import trio
@@ -239,9 +239,12 @@ def done(self) -> bool:
_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"])
+# TODO: variance
+T_Stream = TypeVar("T_Stream", bound=Stream)
+
@final
-class SSLStream(Stream):
+class SSLStream(Stream, Generic[T_Stream]):
r"""Encrypted communication using SSL/TLS.
:class:`SSLStream` wraps an arbitrary :class:`~trio.abc.Stream`, and
@@ -339,14 +342,14 @@ class SSLStream(Stream):
# SSLListener.__init__, and maybe the open_ssl_over_tcp_* helpers.
def __init__(
self,
- transport_stream: Stream,
+ transport_stream: T_Stream,
ssl_context: _stdlib_ssl.SSLContext,
*,
server_hostname: str | bytes | None = None,
server_side: bool = False,
https_compatible: bool = False,
) -> None:
- self.transport_stream: Stream = transport_stream
+ self.transport_stream: T_Stream = transport_stream
self._state = _State.OK
self._https_compatible = https_compatible
self._outgoing = _stdlib_ssl.MemoryBIO()
diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py
index 7855aede6f..bcfc450362 100644
--- a/trio/_tests/test_exports.py
+++ b/trio/_tests/test_exports.py
@@ -12,7 +12,7 @@
from collections.abc import Iterator
from pathlib import Path
from types import ModuleType
-from typing import Protocol
+from typing import Iterable, Protocol
import attrs
import pytest
@@ -126,10 +126,10 @@ def iter_modules(
# https://github.com/pypa/setuptools/issues/3274
"ignore:module 'sre_constants' is deprecated:DeprecationWarning",
)
-def test_static_tool_sees_all_symbols(tool, modname: str, tmp_path) -> None:
+def test_static_tool_sees_all_symbols(tool: str, modname: str, tmp_path: Path) -> None:
module = importlib.import_module(modname)
- def no_underscores(symbols):
+ def no_underscores(symbols: Iterable[str]) -> set[str]:
return {symbol for symbol in symbols if not symbol.startswith("_")}
runtime_names = no_underscores(dir(module))
@@ -278,7 +278,7 @@ def test_static_tool_sees_class_members(
module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)]
# ignore hidden, but not dunder, symbols
- def no_hidden(symbols):
+ def no_hidden(symbols: Iterable[str]) -> set[str]:
return {
symbol
for symbol in symbols
@@ -316,7 +316,7 @@ def no_hidden(symbols):
# skip a bunch of file-system activity (probably can un-memoize?)
@functools.lru_cache
- def lookup_symbol(symbol):
+ def lookup_symbol(symbol: str) -> dict[str, str]:
topname, *modname, name = symbol.split(".")
version = next(cache.glob("3.*/"))
mod_cache = version / topname
@@ -333,7 +333,7 @@ def lookup_symbol(symbol):
mod_cache = mod_cache / (modname[-1] + ".data.json")
with mod_cache.open() as f:
- return json.loads(f.read())["names"][name]
+ return json.loads(f.read())["names"][name] # type: ignore[no-any-return]
errors: dict[str, object] = {}
for class_name, class_ in module.__dict__.items():
@@ -441,6 +441,16 @@ def lookup_symbol(symbol):
extra = {e for e in extra if not e.endswith("AttrsAttributes__")}
assert len(extra) == before - 1
+ import enum
+
+ # mypy does not see these attributes in Enum subclasses
+ if (
+ tool == "mypy"
+ and enum.Enum in class_.__mro__
+ and sys.version_info >= (3, 11)
+ ):
+ extra.difference_update({"__copy__", "__deepcopy__"})
+
# TODO: this *should* be visible via `dir`!!
if tool == "mypy" and class_ == trio.Nursery:
extra.remove("cancel_scope")
diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py
index 2e333068ec..4a267222a0 100644
--- a/trio/_tests/test_highlevel_serve_listeners.py
+++ b/trio/_tests/test_highlevel_serve_listeners.py
@@ -1,7 +1,6 @@
from __future__ import annotations
import errno
-import sys
from functools import partial
from typing import Awaitable, Callable, NoReturn
@@ -20,9 +19,6 @@
wait_all_tasks_blocked,
)
-if sys.version_info < (3, 11):
- from exceptiongroup import ExceptionGroup
-
# types are somewhat tentative - I just bruteforced them until I got something that didn't
# give errors
TypeThing = StapledStream[MemorySendStream, MemoryReceiveStream]
@@ -138,7 +134,8 @@ async def raise_EMFILE() -> NoReturn:
assert len(caplog.records) == 10
for record in caplog.records:
assert "retrying" in record.msg
- assert isinstance(record.exc_info, ExceptionGroup)
+ assert record.exc_info is not None
+ assert isinstance(record.exc_info[1], OSError)
assert record.exc_info[1].errno == errno.EMFILE
diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py
index afb5f30a6b..5a06279204 100644
--- a/trio/_tests/test_highlevel_ssl_helpers.py
+++ b/trio/_tests/test_highlevel_ssl_helpers.py
@@ -1,12 +1,16 @@
from __future__ import annotations
from functools import partial
+from socket import AddressFamily, SocketKind
+from ssl import SSLContext
+from typing import Any, NoReturn
import attr
import pytest
import trio
import trio.testing
+from trio.abc import Stream
from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM
from .._highlevel_socket import SocketListener
@@ -21,7 +25,7 @@
from .test_ssl import SERVER_CTX, client_ctx # noqa: F401
-async def echo_handler(stream) -> None:
+async def echo_handler(stream: Stream) -> None:
async with stream:
try:
while True:
@@ -37,19 +41,35 @@ async def echo_handler(stream) -> None:
# you ask for.
@attr.s
class FakeHostnameResolver(trio.abc.HostnameResolver):
- sockaddr = attr.ib()
-
- async def getaddrinfo(self, *args):
+ sockaddr: tuple[str, int] | tuple[str, int, int, int] = attr.ib()
+
+ async def getaddrinfo(
+ self,
+ host: bytes | str | None,
+ port: bytes | str | int | None,
+ family: int = 0,
+ type: int = 0,
+ proto: int = 0,
+ flags: int = 0,
+ ) -> list[
+ tuple[
+ AddressFamily,
+ SocketKind,
+ int,
+ str,
+ tuple[str, int] | tuple[str, int, int, int],
+ ]
+ ]:
return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)]
- async def getnameinfo(self, *args): # pragma: no cover
+ async def getnameinfo(self, *args: Any) -> NoReturn: # pragma: no cover
raise NotImplementedError
# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
# using noqa because linters don't understand how pytest fixtures work.
async def test_open_ssl_over_tcp_stream_and_everything_else(
- client_ctx, # noqa: F811 # linters doesn't understand fixture
+ client_ctx: SSLContext, # noqa: F811 # linters doesn't understand fixture
) -> None:
async with trio.open_nursery() as nursery:
# TODO: the types are *very* funky here, this seems like an error in some signature
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index 53943cabce..5dece46e54 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -7,11 +7,18 @@
import threading
from contextlib import asynccontextmanager, contextmanager
from functools import partial
+from ssl import SSLContext
+from typing import Any, AsyncIterator, Iterator, NoReturn
import pytest
+from typing_extensions import TypeAlias
+from trio import StapledStream
from trio._core import MockClock
+from trio._ssl import T_Stream
from trio._tests.pytest_plugin import skip_if_optional_else_raise
+from trio.abc import ReceiveStream, SendStream
+from trio.testing import MemoryReceiveStream, MemorySendStream
try:
import trustme
@@ -24,6 +31,7 @@
from .. import _core, socket as tsocket
from .._abc import Stream
from .._core import BrokenResourceError, ClosedResourceError
+from .._core._run import CancelScope
from .._core._tests.tutil import slow
from .._highlevel_generic import aclose_forcefully
from .._highlevel_open_tcp_stream import open_tcp_stream
@@ -73,7 +81,7 @@
# downgrade on the server side. "tls12" means we refuse to negotiate TLS
# 1.3, so we'll almost certainly use TLS 1.2.
@pytest.fixture(scope="module", params=["tls13", "tls12"])
-def client_ctx(request):
+def client_ctx(request: pytest.FixtureRequest) -> ssl.SSLContext:
ctx = ssl.create_default_context()
if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
@@ -90,7 +98,9 @@ def client_ctx(request):
# The blocking socket server.
-def ssl_echo_serve_sync(sock, *, expect_fail: bool = False):
+def ssl_echo_serve_sync(
+ sock: stdlib_socket.socket, *, expect_fail: bool = False
+) -> None:
try:
wrapped = SERVER_CTX.wrap_socket(
sock, server_side=True, suppress_ragged_eofs=False
@@ -142,8 +152,8 @@ def ssl_echo_serve_sync(sock, *, expect_fail: bool = False):
# Fixture that gives a raw socket connected to a trio-test-1 echo server
# (running in a thread). Useful for testing making connections with different
# SSLContexts.
-@asynccontextmanager
-async def ssl_echo_server_raw(**kwargs):
+@asynccontextmanager # type: ignore[misc] # decorated contains Any
+async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]:
a, b = stdlib_socket.socketpair()
async with trio.open_nursery() as nursery:
# Exiting the 'with a, b' context manager closes the sockets, which
@@ -159,8 +169,10 @@ async def ssl_echo_server_raw(**kwargs):
# Fixture that gives a properly set up SSLStream connected to a trio-test-1
# echo server (running in a thread)
-@asynccontextmanager
-async def ssl_echo_server(client_ctx, **kwargs):
+@asynccontextmanager # type: ignore[misc] # decorated contains Any
+async def ssl_echo_server(
+ client_ctx: SSLContext, **kwargs: Any
+) -> AsyncIterator[SSLStream[Stream]]:
async with ssl_echo_server_raw(**kwargs) as sock:
yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org")
@@ -171,7 +183,7 @@ async def ssl_echo_server(client_ctx, **kwargs):
# jakkdl: it seems to implement all the abstract methods (now), so I made it inherit
# from Stream for the sake of typechecking.
class PyOpenSSLEchoStream(Stream):
- def __init__(self, sleeper=None) -> None:
+ def __init__(self, sleeper: None = None) -> None:
ctx = SSL.Context(SSL.SSLv23_METHOD)
# TLS 1.3 removes renegotiation support. Which is great for them, but
# we still have to support versions before that, and that means we
@@ -221,7 +233,7 @@ def __init__(self, sleeper=None) -> None:
if sleeper is None:
- async def no_op_sleeper(_) -> None:
+ async def no_op_sleeper(_: object) -> None:
return
self.sleeper = no_op_sleeper
@@ -231,7 +243,7 @@ async def no_op_sleeper(_) -> None:
async def aclose(self) -> None:
self._conn.bio_shutdown()
- def renegotiate_pending(self):
+ def renegotiate_pending(self) -> bool:
return self._conn.renegotiate_pending()
def renegotiate(self) -> None:
@@ -245,7 +257,7 @@ async def wait_send_all_might_not_block(self) -> None:
await _core.checkpoint()
await self.sleeper("wait_send_all_might_not_block")
- async def send_all(self, data) -> None:
+ async def send_all(self, data: bytes) -> None:
print(" --> transport_stream.send_all")
with self._send_all_conflict_detector:
await _core.checkpoint()
@@ -268,7 +280,7 @@ async def send_all(self, data) -> None:
await self.sleeper("send_all")
print(" <-- transport_stream.send_all finished")
- async def receive_some(self, nbytes=None):
+ async def receive_some(self, nbytes: int | None = None) -> bytes:
print(" --> transport_stream.receive_some")
if nbytes is None:
nbytes = 65536 # arbitrary
@@ -360,20 +372,22 @@ async def test_PyOpenSSLEchoStream_gives_resource_busy_errors() -> None:
assert "simultaneous" in str(excinfo.value)
-@contextmanager
-def virtual_ssl_echo_server(client_ctx, **kwargs):
+@contextmanager # type: ignore[misc] # decorated contains Any
+def virtual_ssl_echo_server(
+ client_ctx: SSLContext, **kwargs: Any
+) -> Iterator[SSLStream[PyOpenSSLEchoStream]]:
fakesock = PyOpenSSLEchoStream(**kwargs)
yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org")
def ssl_wrap_pair(
- client_ctx,
- client_transport,
- server_transport,
+ client_ctx: SSLContext,
+ client_transport: T_Stream,
+ server_transport: T_Stream,
*,
- client_kwargs={},
- server_kwargs={},
-):
+ client_kwargs: dict[str, Any] = {},
+ server_kwargs: dict[str, Any] = {},
+) -> tuple[SSLStream[T_Stream], SSLStream[T_Stream]]:
client_ssl = SSLStream(
client_transport,
client_ctx,
@@ -386,12 +400,22 @@ def ssl_wrap_pair(
return client_ssl, server_ssl
-def ssl_memory_stream_pair(client_ctx, **kwargs):
+MemoryStapledStream: TypeAlias = StapledStream[MemorySendStream, MemoryReceiveStream]
+
+
+def ssl_memory_stream_pair(
+ client_ctx: SSLContext, **kwargs: Any
+) -> tuple[SSLStream[MemoryStapledStream], SSLStream[MemoryStapledStream],]:
client_transport, server_transport = memory_stream_pair()
return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs)
-def ssl_lockstep_stream_pair(client_ctx, **kwargs):
+MyStapledStream: TypeAlias = StapledStream[SendStream, ReceiveStream]
+
+
+def ssl_lockstep_stream_pair(
+ client_ctx: SSLContext, **kwargs: Any
+) -> tuple[SSLStream[MyStapledStream], SSLStream[MyStapledStream],]:
client_transport, server_transport = lockstep_stream_pair()
return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs)
@@ -399,7 +423,7 @@ def ssl_lockstep_stream_pair(client_ctx, **kwargs):
# Simple smoke test for handshake/send/receive/shutdown talking to a
# synchronous server, plus make sure that we do the bare minimum of
# certificate checking (even though this is really Python's responsibility)
-async def test_ssl_client_basics(client_ctx) -> None:
+async def test_ssl_client_basics(client_ctx: SSLContext) -> None:
# Everything OK
async with ssl_echo_server(client_ctx) as s:
assert not s.server_side
@@ -425,7 +449,7 @@ async def test_ssl_client_basics(client_ctx) -> None:
assert isinstance(excinfo.value.__cause__, ssl.CertificateError)
-async def test_ssl_server_basics(client_ctx) -> None:
+async def test_ssl_server_basics(client_ctx: SSLContext) -> None:
a, b = stdlib_socket.socketpair()
with a, b:
server_sock = tsocket.from_stdlib_socket(b)
@@ -455,7 +479,7 @@ def client() -> None:
t.join()
-async def test_attributes(client_ctx) -> None:
+async def test_attributes(client_ctx: SSLContext) -> None:
async with ssl_echo_server_raw(expect_fail=True) as sock:
good_ctx = client_ctx
bad_ctx = ssl.create_default_context()
@@ -524,7 +548,7 @@ async def test_attributes(client_ctx) -> None:
# I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it...
-async def test_full_duplex_basics(client_ctx) -> None:
+async def test_full_duplex_basics(client_ctx: SSLContext) -> None:
CHUNKS = 30
CHUNK_SIZE = 32768
EXPECTED = CHUNKS * CHUNK_SIZE
@@ -532,7 +556,7 @@ async def test_full_duplex_basics(client_ctx) -> None:
sent = bytearray()
received = bytearray()
- async def sender(s) -> None:
+ async def sender(s: Stream) -> None:
nonlocal sent
for i in range(CHUNKS):
print(i)
@@ -540,7 +564,7 @@ async def sender(s) -> None:
sent += chunk
await s.send_all(chunk)
- async def receiver(s) -> None:
+ async def receiver(s: Stream) -> None:
nonlocal received
while len(received) < EXPECTED:
chunk = await s.receive_some(CHUNK_SIZE // 2)
@@ -561,10 +585,9 @@ async def receiver(s) -> None:
assert sent == received
-async def test_renegotiation_simple(client_ctx) -> None:
+async def test_renegotiation_simple(client_ctx: SSLContext) -> None:
with virtual_ssl_echo_server(client_ctx) as s:
await s.do_handshake()
-
s.transport_stream.renegotiate()
await s.send_all(b"a")
assert await s.receive_some(1) == b"a"
@@ -580,7 +603,9 @@ async def test_renegotiation_simple(client_ctx) -> None:
@slow
-async def test_renegotiation_randomized(mock_clock: MockClock, client_ctx) -> None:
+async def test_renegotiation_randomized(
+ mock_clock: MockClock, client_ctx: SSLContext
+) -> None:
# The only blocking things in this function are our random sleeps, so 0 is
# a good threshold.
mock_clock.autojump_threshold = 0
@@ -589,7 +614,7 @@ async def test_renegotiation_randomized(mock_clock: MockClock, client_ctx) -> No
r = random.Random(0)
- async def sleeper(_) -> None:
+ async def sleeper(_: object) -> None:
await trio.sleep(r.uniform(0, 10))
async def clear() -> None:
@@ -600,13 +625,13 @@ async def clear() -> None:
await expect(b"-")
print("-- clear --")
- async def send(byte) -> None:
+ async def send(byte: bytes) -> None:
await s.transport_stream.sleeper("outer send")
print("calling SSLStream.send_all", byte)
with assert_checkpoints():
await s.send_all(byte)
- async def expect(expected) -> None:
+ async def expect(expected: bytes) -> None:
await s.transport_stream.sleeper("expect")
print("calling SSLStream.receive_some, expecting", expected)
assert len(expected) == 1
@@ -652,7 +677,7 @@ async def expect(expected) -> None:
# and wait_send_all_might_not_block comes in.
# Our receive_some() call will get stuck when it hits send_all
- async def sleeper_with_slow_send_all(method) -> None:
+ async def sleeper_with_slow_send_all(method: str) -> None:
if method == "send_all":
await trio.sleep(100000)
@@ -676,7 +701,7 @@ async def sleep_then_wait_writable() -> None:
# 2) Same, but now wait_send_all_might_not_block is stuck when
# receive_some tries to send.
- async def sleeper_with_slow_wait_writable_and_expect(method) -> None:
+ async def sleeper_with_slow_wait_writable_and_expect(method: str) -> None:
if method == "wait_send_all_might_not_block":
await trio.sleep(100000)
elif method == "expect":
@@ -696,7 +721,7 @@ async def sleeper_with_slow_wait_writable_and_expect(method) -> None:
await s.aclose()
-async def test_resource_busy_errors(client_ctx) -> None:
+async def test_resource_busy_errors(client_ctx: SSLContext) -> None:
async def do_send_all() -> None:
with assert_checkpoints():
await s.send_all(b"x")
@@ -765,7 +790,7 @@ async def send_all(self, data: bytes | bytearray | memoryview) -> None:
os.name == "nt" and sys.version_info >= (3, 10),
reason="frequently fails on Windows + Python 3.10",
)
-async def test_checkpoints(client_ctx) -> None:
+async def test_checkpoints(client_ctx: SSLContext) -> None:
async with ssl_echo_server(client_ctx) as s:
with assert_checkpoints():
await s.do_handshake()
@@ -794,7 +819,7 @@ async def test_checkpoints(client_ctx) -> None:
await s.aclose()
-async def test_send_all_empty_string(client_ctx) -> None:
+async def test_send_all_empty_string(client_ctx: SSLContext) -> None:
async with ssl_echo_server(client_ctx) as s:
await s.do_handshake()
@@ -811,15 +836,23 @@ async def test_send_all_empty_string(client_ctx) -> None:
@pytest.mark.parametrize("https_compatible", [False, True])
-async def test_SSLStream_generic(client_ctx, https_compatible) -> None:
- async def stream_maker():
+async def test_SSLStream_generic(
+ client_ctx: SSLContext, https_compatible: bool
+) -> None:
+ async def stream_maker() -> tuple[
+ SSLStream[MemoryStapledStream],
+ SSLStream[MemoryStapledStream],
+ ]:
return ssl_memory_stream_pair(
client_ctx,
client_kwargs={"https_compatible": https_compatible},
server_kwargs={"https_compatible": https_compatible},
)
- async def clogged_stream_maker():
+ async def clogged_stream_maker() -> tuple[
+ SSLStream[MyStapledStream],
+ SSLStream[MyStapledStream],
+ ]:
client, server = ssl_lockstep_stream_pair(client_ctx)
# If we don't do handshakes up front, then we run into a problem in
# the following situation:
@@ -835,7 +868,7 @@ async def clogged_stream_maker():
await check_two_way_stream(stream_maker, clogged_stream_maker)
-async def test_unwrap(client_ctx) -> None:
+async def test_unwrap(client_ctx: SSLContext) -> None:
client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
client_transport = client_ssl.transport_stream
server_transport = server_ssl.transport_stream
@@ -889,7 +922,7 @@ async def server() -> None:
nursery.start_soon(server)
-async def test_closing_nice_case(client_ctx) -> None:
+async def test_closing_nice_case(client_ctx: SSLContext) -> None:
# the nice case: graceful closes all around
client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
@@ -951,14 +984,14 @@ async def expect_eof_server() -> None:
nursery.start_soon(expect_eof_server)
-async def test_send_all_fails_in_the_middle(client_ctx) -> None:
+async def test_send_all_fails_in_the_middle(client_ctx: SSLContext) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
async with _core.open_nursery() as nursery:
nursery.start_soon(client.do_handshake)
nursery.start_soon(server.do_handshake)
- async def bad_hook():
+ async def bad_hook() -> NoReturn:
raise KeyError
client.transport_stream.send_stream.send_all_hook = bad_hook
@@ -982,7 +1015,7 @@ def close_hook() -> None:
assert closed == 2
-async def test_ssl_over_ssl(client_ctx) -> None:
+async def test_ssl_over_ssl(client_ctx: SSLContext) -> None:
client_0, server_0 = memory_stream_pair()
client_1 = SSLStream(
@@ -1008,7 +1041,7 @@ async def server() -> None:
nursery.start_soon(server)
-async def test_ssl_bad_shutdown(client_ctx) -> None:
+async def test_ssl_bad_shutdown(client_ctx: SSLContext) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
async with _core.open_nursery() as nursery:
@@ -1025,7 +1058,7 @@ async def test_ssl_bad_shutdown(client_ctx) -> None:
await server.aclose()
-async def test_ssl_bad_shutdown_but_its_ok(client_ctx) -> None:
+async def test_ssl_bad_shutdown_but_its_ok(client_ctx: SSLContext) -> None:
client, server = ssl_memory_stream_pair(
client_ctx,
server_kwargs={"https_compatible": True},
@@ -1064,7 +1097,7 @@ async def test_ssl_handshake_failure_during_aclose() -> None:
await s.aclose()
-async def test_ssl_only_closes_stream_once(client_ctx) -> None:
+async def test_ssl_only_closes_stream_once(client_ctx: SSLContext) -> None:
# We used to have a bug where if transport_stream.aclose() raised an
# error, we would call it again. This checks that that's fixed.
client, server = ssl_memory_stream_pair(client_ctx)
@@ -1076,8 +1109,9 @@ async def test_ssl_only_closes_stream_once(client_ctx) -> None:
client_orig_close_hook = client.transport_stream.send_stream.close_hook
transport_close_count = 0
- def close_hook():
+ def close_hook() -> NoReturn:
nonlocal transport_close_count
+ assert client_orig_close_hook is not None
client_orig_close_hook()
transport_close_count += 1
raise KeyError
@@ -1089,7 +1123,7 @@ def close_hook():
assert transport_close_count == 1
-async def test_ssl_https_compatibility_disagreement(client_ctx) -> None:
+async def test_ssl_https_compatibility_disagreement(client_ctx: SSLContext) -> None:
client, server = ssl_memory_stream_pair(
client_ctx,
server_kwargs={"https_compatible": False},
@@ -1113,7 +1147,7 @@ async def receive_and_expect_error() -> None:
nursery.start_soon(receive_and_expect_error)
-async def test_https_mode_eof_before_handshake(client_ctx) -> None:
+async def test_https_mode_eof_before_handshake(client_ctx: SSLContext) -> None:
client, server = ssl_memory_stream_pair(
client_ctx,
server_kwargs={"https_compatible": True},
@@ -1128,10 +1162,10 @@ async def server_expect_clean_eof() -> None:
nursery.start_soon(server_expect_clean_eof)
-async def test_send_error_during_handshake(client_ctx) -> None:
+async def test_send_error_during_handshake(client_ctx: SSLContext) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
- async def bad_hook():
+ async def bad_hook() -> NoReturn:
raise KeyError
client.transport_stream.send_stream.send_all_hook = bad_hook
@@ -1145,15 +1179,15 @@ async def bad_hook():
await client.do_handshake()
-async def test_receive_error_during_handshake(client_ctx) -> None:
+async def test_receive_error_during_handshake(client_ctx: SSLContext) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
- async def bad_hook():
+ async def bad_hook() -> NoReturn:
raise KeyError
client.transport_stream.receive_stream.receive_some_hook = bad_hook
- async def client_side(cancel_scope) -> None:
+ async def client_side(cancel_scope: CancelScope) -> None:
with pytest.raises(KeyError):
with assert_checkpoints():
await client.do_handshake()
@@ -1168,7 +1202,7 @@ async def client_side(cancel_scope) -> None:
await client.do_handshake()
-async def test_selected_alpn_protocol_before_handshake(client_ctx) -> None:
+async def test_selected_alpn_protocol_before_handshake(client_ctx: SSLContext) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
with pytest.raises(NeedHandshakeError):
@@ -1178,7 +1212,7 @@ async def test_selected_alpn_protocol_before_handshake(client_ctx) -> None:
server.selected_alpn_protocol()
-async def test_selected_alpn_protocol_when_not_set(client_ctx) -> None:
+async def test_selected_alpn_protocol_when_not_set(client_ctx: SSLContext) -> None:
# ALPN protocol still returns None when it's not set,
# instead of raising an exception
client, server = ssl_memory_stream_pair(client_ctx)
@@ -1193,7 +1227,7 @@ async def test_selected_alpn_protocol_when_not_set(client_ctx) -> None:
assert client.selected_alpn_protocol() == server.selected_alpn_protocol()
-async def test_selected_npn_protocol_before_handshake(client_ctx) -> None:
+async def test_selected_npn_protocol_before_handshake(client_ctx: SSLContext) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
with pytest.raises(NeedHandshakeError):
@@ -1207,7 +1241,7 @@ async def test_selected_npn_protocol_before_handshake(client_ctx) -> None:
r"ignore: ssl module. NPN is deprecated, use ALPN instead:UserWarning",
r"ignore:ssl NPN is deprecated, use ALPN instead:DeprecationWarning",
)
-async def test_selected_npn_protocol_when_not_set(client_ctx) -> None:
+async def test_selected_npn_protocol_when_not_set(client_ctx: SSLContext) -> None:
# NPN protocol still returns None when it's not set,
# instead of raising an exception
client, server = ssl_memory_stream_pair(client_ctx)
@@ -1222,7 +1256,7 @@ async def test_selected_npn_protocol_when_not_set(client_ctx) -> None:
assert client.selected_npn_protocol() == server.selected_npn_protocol()
-async def test_get_channel_binding_before_handshake(client_ctx) -> None:
+async def test_get_channel_binding_before_handshake(client_ctx: SSLContext) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
with pytest.raises(NeedHandshakeError):
@@ -1232,7 +1266,7 @@ async def test_get_channel_binding_before_handshake(client_ctx) -> None:
server.get_channel_binding()
-async def test_get_channel_binding_after_handshake(client_ctx) -> None:
+async def test_get_channel_binding_after_handshake(client_ctx: SSLContext) -> None:
client, server = ssl_memory_stream_pair(client_ctx)
async with _core.open_nursery() as nursery:
@@ -1245,7 +1279,7 @@ async def test_get_channel_binding_after_handshake(client_ctx) -> None:
assert client.get_channel_binding() == server.get_channel_binding()
-async def test_getpeercert(client_ctx) -> None:
+async def test_getpeercert(client_ctx: SSLContext) -> None:
# Make sure we're not affected by https://bugs.python.org/issue29334
client, server = ssl_memory_stream_pair(client_ctx)
@@ -1258,8 +1292,10 @@ async def test_getpeercert(client_ctx) -> None:
assert ("DNS", "trio-test-1.example.org") in client.getpeercert()["subjectAltName"]
-async def test_SSLListener(client_ctx) -> None:
- async def setup(**kwargs):
+async def test_SSLListener(client_ctx: SSLContext) -> None:
+ async def setup(
+ **kwargs: Any,
+ ) -> tuple[tsocket.SocketType, SSLListener, SSLStream[SocketStream]]:
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(1)
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index 3ebf2589bc..b0c93cdf7a 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -11,7 +11,7 @@
from typing import TYPE_CHECKING, AsyncIterator
import pytest
-from pytest import MonkeyPatch
+from pytest import MonkeyPatch, WarningsRecorder
from .. import (
ClosedResourceError,
@@ -153,7 +153,7 @@ async def test_multi_wait(background_process) -> None:
# Test for deprecated 'async with process:' semantics
-async def test_async_with_basics_deprecated(recwarn) -> None:
+async def test_async_with_basics_deprecated(recwarn: WarningsRecorder) -> None:
async with await open_process(
CAT, stdin=subprocess.PIPE, stdout=subprocess.PIPE
) as proc:
@@ -168,7 +168,7 @@ async def test_async_with_basics_deprecated(recwarn) -> None:
# Test for deprecated 'async with process:' semantics
-async def test_kill_when_context_cancelled(recwarn) -> None:
+async def test_kill_when_context_cancelled(recwarn: WarningsRecorder) -> None:
with move_on_after(100) as scope:
async with await open_process(SLEEP(10)) as proc:
assert proc.poll() is None
diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py
index 91c865e085..448ead15da 100644
--- a/trio/_tests/test_sync.py
+++ b/trio/_tests/test_sync.py
@@ -1,6 +1,10 @@
+from __future__ import annotations
+
import weakref
+from typing import Callable, Union
import pytest
+from typing_extensions import TypeAlias
from .. import _core
from .._sync import *
@@ -204,7 +208,7 @@ async def test_Semaphore() -> None:
record = []
- async def do_acquire(s) -> None:
+ async def do_acquire(s: Semaphore) -> None:
record.append("started")
await s.acquire()
record.append("finished")
@@ -240,7 +244,9 @@ async def test_Semaphore_bounded() -> None:
@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
-async def test_Lock_and_StrictFIFOLock(lockcls) -> None:
+async def test_Lock_and_StrictFIFOLock(
+ lockcls: type[Lock] | type[StrictFIFOLock],
+) -> None:
l = lockcls() # noqa
assert not l.locked()
@@ -349,7 +355,7 @@ async def test_Condition() -> None:
finished_waiters = set()
- async def waiter(i) -> None:
+ async def waiter(i: int) -> None:
async with c:
await c.wait()
finished_waiters.add(i)
@@ -489,17 +495,28 @@ def release(self) -> None:
"lock_factory", lock_factories, ids=lock_factory_names
)
+LockLike: TypeAlias = Union[
+ CapacityLimiter,
+ Semaphore,
+ Lock,
+ StrictFIFOLock,
+ ChannelLock1,
+ ChannelLock2,
+ ChannelLock3,
+]
+LockFactory: TypeAlias = Callable[[], LockLike]
+
# Spawn a bunch of workers that take a lock and then yield; make sure that
# only one worker is ever in the critical section at a time.
@generic_lock_test
-async def test_generic_lock_exclusion(lock_factory) -> None:
+async def test_generic_lock_exclusion(lock_factory: LockFactory) -> None:
LOOPS = 10
WORKERS = 5
in_critical_section = False
acquires = 0
- async def worker(lock_like) -> None:
+ async def worker(lock_like: LockLike) -> None:
nonlocal in_critical_section, acquires
for _ in range(LOOPS):
async with lock_like:
@@ -522,12 +539,12 @@ async def worker(lock_like) -> None:
# Several workers queue on the same lock; make sure they each get it, in
# order.
@generic_lock_test
-async def test_generic_lock_fifo_fairness(lock_factory) -> None:
+async def test_generic_lock_fifo_fairness(lock_factory: LockFactory) -> None:
initial_order = []
record = []
LOOPS = 5
- async def loopy(name: str, lock_like) -> None:
+ async def loopy(name: str, lock_like: LockLike) -> None:
# Record the order each task was initially scheduled in
initial_order.append(name)
for _ in range(LOOPS):
@@ -546,7 +563,9 @@ async def loopy(name: str, lock_like) -> None:
@generic_lock_test
-async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory) -> None:
+async def test_generic_lock_acquire_nowait_blocks_acquire(
+ lock_factory: LockFactory,
+) -> None:
lock_like = lock_factory()
record = []
diff --git a/trio/_tests/test_testing.py b/trio/_tests/test_testing.py
index 7461de058c..1c36567560 100644
--- a/trio/_tests/test_testing.py
+++ b/trio/_tests/test_testing.py
@@ -1,12 +1,17 @@
-# XX this should get broken up, like testing.py did
+from __future__ import annotations
+# XX this should get broken up, like testing.py did
import tempfile
import pytest
+from pytest import WarningsRecorder
+
+from trio import Nursery
+from trio.abc import ReceiveStream, SendStream
from .. import _core, sleep, socket as tsocket
from .._core._tests.tutil import can_bind_ipv6
-from .._highlevel_generic import aclose_forcefully
+from .._highlevel_generic import StapledStream, aclose_forcefully
from .._highlevel_socket import SocketListener
from ..testing import *
from ..testing._check_streams import _assert_raises
@@ -104,7 +109,7 @@ async def wait_big_cushion() -> None:
################################################################
-async def test_assert_checkpoints(recwarn) -> None:
+async def test_assert_checkpoints(recwarn: WarningsRecorder) -> None:
with assert_checkpoints():
await _core.checkpoint()
@@ -130,7 +135,7 @@ async def test_assert_checkpoints(recwarn) -> None:
await _core.cancel_shielded_checkpoint()
-async def test_assert_no_checkpoints(recwarn) -> None:
+async def test_assert_no_checkpoints(recwarn: WarningsRecorder) -> None:
with assert_no_checkpoints():
1 + 1
@@ -163,11 +168,11 @@ async def test_assert_no_checkpoints(recwarn) -> None:
async def test_Sequencer() -> None:
record = []
- def t(val) -> None:
+ def t(val: object) -> None:
print(val)
record.append(val)
- async def f1(seq) -> None:
+ async def f1(seq: Sequencer) -> None:
async with seq(1):
t(("f1", 1))
async with seq(3):
@@ -175,7 +180,7 @@ async def f1(seq) -> None:
async with seq(4):
t(("f1", 4))
- async def f2(seq) -> None:
+ async def f2(seq: Sequencer) -> None:
async with seq(0):
t(("f2", 0))
async with seq(2):
@@ -203,7 +208,7 @@ async def test_Sequencer_cancel() -> None:
record = []
seq = Sequencer()
- async def child(i) -> None:
+ async def child(i: int) -> None:
with _core.CancelScope() as scope:
if i == 1:
scope.cancel()
@@ -230,7 +235,7 @@ async def child(i) -> None:
################################################################
-async def test__assert_raises():
+async def test__assert_raises() -> None:
with pytest.raises(AssertionError):
with _assert_raises(RuntimeError):
1 + 1
@@ -273,11 +278,11 @@ async def test__UnboundeByteQueue() -> None:
with assert_checkpoints():
assert await ubq.get() == b"efghi"
- async def putter(data) -> None:
+ async def putter(data: bytes) -> None:
await wait_all_tasks_blocked()
ubq.put(data)
- async def getter(expect) -> None:
+ async def getter(expect: bytes) -> None:
with assert_checkpoints():
assert await ubq.get() == expect
@@ -320,7 +325,7 @@ async def closer() -> None:
async def test_MemorySendStream() -> None:
mss = MemorySendStream()
- async def do_send_all(data) -> None:
+ async def do_send_all(data: bytes) -> None:
with assert_checkpoints():
await mss.send_all(data)
@@ -410,7 +415,7 @@ def close_hook() -> None:
async def test_MemoryReceiveStream() -> None:
mrs = MemoryReceiveStream()
- async def do_receive_some(max_bytes):
+ async def do_receive_some(max_bytes: int | None) -> bytes:
with assert_checkpoints():
return await mrs.receive_some(max_bytes)
@@ -521,7 +526,7 @@ async def test_memory_stream_one_way_pair() -> None:
await s.send_all(b"123")
assert await r.receive_some(10) == b"123"
- async def receiver(expected) -> None:
+ async def receiver(expected: bytes) -> None:
assert await r.receive_some(10) == expected
# This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
@@ -549,7 +554,7 @@ async def receiver(expected) -> None:
s.send_all_hook = None
await s.send_all(b"456")
- async def cancel_after_idle(nursery) -> None:
+ async def cancel_after_idle(nursery: Nursery) -> None:
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
@@ -591,31 +596,37 @@ async def receiver() -> None:
async def test_memory_streams_with_generic_tests() -> None:
- async def one_way_stream_maker():
+ async def one_way_stream_maker() -> tuple[MemorySendStream, MemoryReceiveStream]:
return memory_stream_one_way_pair()
await check_one_way_stream(one_way_stream_maker, None)
- async def half_closeable_stream_maker():
+ async def half_closeable_stream_maker() -> tuple[
+ StapledStream[MemorySendStream, MemoryReceiveStream],
+ StapledStream[MemorySendStream, MemoryReceiveStream],
+ ]:
return memory_stream_pair()
await check_half_closeable_stream(half_closeable_stream_maker, None)
async def test_lockstep_streams_with_generic_tests() -> None:
- async def one_way_stream_maker():
+ async def one_way_stream_maker() -> tuple[SendStream, ReceiveStream]:
return lockstep_stream_one_way_pair()
await check_one_way_stream(one_way_stream_maker, one_way_stream_maker)
- async def two_way_stream_maker():
+ async def two_way_stream_maker() -> tuple[
+ StapledStream[SendStream, ReceiveStream],
+ StapledStream[SendStream, ReceiveStream],
+ ]:
return lockstep_stream_pair()
await check_two_way_stream(two_way_stream_maker, two_way_stream_maker)
async def test_open_stream_to_socket_listener() -> None:
- async def check(listener) -> None:
+ async def check(listener: SocketListener) -> None:
async with listener:
client_stream = await open_stream_to_socket_listener(listener)
async with client_stream:
diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py
index 0fe5d8dc48..80b870445d 100644
--- a/trio/_tests/test_threads.py
+++ b/trio/_tests/test_threads.py
@@ -360,7 +360,9 @@ async def child(q, cancellable):
# Make sure that if trio.run exits, and then the thread finishes, then that's
# handled gracefully. (Requires that the thread result machinery be prepared
# for call_soon to raise RunFinishedError.)
-def test_run_in_worker_thread_abandoned(capfd, monkeypatch: MonkeyPatch) -> None:
+def test_run_in_worker_thread_abandoned(
+ capfd: pytest.CaptureFixture[str], monkeypatch: MonkeyPatch
+) -> None:
monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)
q1 = stdlib_queue.Queue[None]()
diff --git a/trio/_tests/test_util.py b/trio/_tests/test_util.py
index f074a06096..01a7503e9d 100644
--- a/trio/_tests/test_util.py
+++ b/trio/_tests/test_util.py
@@ -1,6 +1,7 @@
import signal
import sys
import types
+from typing import Any, TypeVar
import pytest
@@ -23,11 +24,13 @@
)
from ..testing import wait_all_tasks_blocked
+T = TypeVar("T")
+
def test_signal_raise() -> None:
record = []
- def handler(signum, _) -> None:
+ def handler(signum: int, _: object) -> None:
record.append(signum)
old = signal.signal(signal.SIGFPE, handler)
@@ -114,9 +117,9 @@ async def f() -> None: # pragma: no cover
import asyncio
if sys.version_info < (3, 11):
-
- @asyncio.coroutine
- def generator_based_coro(): # pragma: no cover
+ # not bothering to type this one
+ @asyncio.coroutine # type: ignore[misc]
+ def generator_based_coro() -> Any: # pragma: no cover
yield from asyncio.sleep(1)
with pytest.raises(TypeError) as excinfo:
@@ -145,7 +148,7 @@ def generator_based_coro(): # pragma: no cover
assert "appears to be synchronous" in str(excinfo.value)
- async def async_gen(arg): # pragma: no cover
+ async def async_gen(_: object) -> Any: # pragma: no cover
yield
# does not give arg-type typing error
@@ -160,7 +163,7 @@ async def async_gen(arg): # pragma: no cover
def test_generic_function() -> None:
@generic_function
- def test_func(arg):
+ def test_func(arg: T) -> T:
"""Look, a docstring!"""
return arg
diff --git a/trio/_tests/test_wait_for_object.py b/trio/_tests/test_wait_for_object.py
index 44790497ed..b41bcba3a5 100644
--- a/trio/_tests/test_wait_for_object.py
+++ b/trio/_tests/test_wait_for_object.py
@@ -153,7 +153,7 @@ async def test_WaitForSingleObject() -> None:
# Not a handle
with pytest.raises(TypeError):
- await WaitForSingleObject("not a handle") # Wrong type
+ await WaitForSingleObject("not a handle") # type: ignore[arg-type] # Wrong type
# with pytest.raises(OSError):
# await WaitForSingleObject(99) # If you're unlucky, it actually IS a handle :(
print("test_WaitForSingleObject not a handle OK")
diff --git a/trio/_tests/tools/test_gen_exports.py b/trio/_tests/tools/test_gen_exports.py
index ee51827a59..4bd8b7c089 100644
--- a/trio/_tests/tools/test_gen_exports.py
+++ b/trio/_tests/tools/test_gen_exports.py
@@ -1,5 +1,6 @@
import ast
import sys
+from pathlib import Path
import pytest
@@ -91,7 +92,7 @@ def test_create_pass_through_args() -> None:
@skip_lints
@pytest.mark.parametrize("imports", ["", IMPORT_1, IMPORT_2, IMPORT_3])
-def test_process(tmp_path, imports) -> None:
+def test_process(tmp_path: Path, imports: str) -> None:
try:
import black # noqa: F401
# there's no dedicated CI run that has astor+isort, but lacks black.
@@ -123,7 +124,7 @@ def test_process(tmp_path, imports) -> None:
@skip_lints
-def test_run_black(tmp_path) -> None:
+def test_run_black(tmp_path: Path) -> None:
"""Test that processing properly fails if black does."""
try:
import black # noqa: F401
@@ -140,7 +141,7 @@ def test_run_black(tmp_path) -> None:
@skip_lints
-def test_run_ruff(tmp_path) -> None:
+def test_run_ruff(tmp_path: Path) -> None:
"""Test that processing properly fails if black does."""
try:
import ruff # noqa: F401
@@ -166,7 +167,7 @@ def test_run_ruff(tmp_path) -> None:
@skip_lints
-def test_lint_failure(tmp_path) -> None:
+def test_lint_failure(tmp_path: Path) -> None:
"""Test that processing properly fails if black or ruff does."""
try:
import black # noqa: F401
From fdd3df3169c7039c0add1e65dc61d270fd54f6e0 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 18 Oct 2023 15:11:51 +0200
Subject: [PATCH 09/35] type test_file_io
---
pyproject.toml | 1 -
trio/_tests/test_file_io.py | 40 +++++++++++++++++++++++++------------
2 files changed, 27 insertions(+), 14 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 831d577e65..0e4b7671f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,7 +87,6 @@ module = [
"trio/_core/_tests/test_guest_mode",
"trio/_core/_tests/test_ki",
"trio/_core/_tests/test_multierror",
-"trio/_tests/test_file_io",
"trio/_tests/test_subprocess",
"trio/_tests/test_threads",
]
diff --git a/trio/_tests/test_file_io.py b/trio/_tests/test_file_io.py
index 2ef02f1145..d438a9fb10 100644
--- a/trio/_tests/test_file_io.py
+++ b/trio/_tests/test_file_io.py
@@ -20,12 +20,12 @@ def path(tmp_path: pathlib.Path) -> str:
@pytest.fixture
-def wrapped():
+def wrapped() -> mock.Mock:
return mock.Mock(spec_set=io.StringIO)
@pytest.fixture
-def async_file(wrapped):
+def async_file(wrapped: mock.Mock) -> AsyncIOWrapper[mock.Mock]:
return trio.wrap_file(wrapped)
@@ -54,11 +54,15 @@ def write(self) -> None: # pragma: no cover
trio.wrap_file(FakeFile())
-def test_wrapped_property(async_file, wrapped) -> None:
+def test_wrapped_property(
+ async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock
+) -> None:
assert async_file.wrapped is wrapped
-def test_dir_matches_wrapped(async_file, wrapped) -> None:
+def test_dir_matches_wrapped(
+ async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock
+) -> None:
attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS)
# all supported attrs in wrapped should be available in async_file
@@ -122,7 +126,9 @@ def test_type_stubs_match_lists() -> None:
assert found == expected
-def test_sync_attrs_forwarded(async_file, wrapped) -> None:
+def test_sync_attrs_forwarded(
+ async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock
+) -> None:
for attr_name in _FILE_SYNC_ATTRS:
if attr_name not in dir(async_file):
continue
@@ -130,7 +136,9 @@ def test_sync_attrs_forwarded(async_file, wrapped) -> None:
assert getattr(async_file, attr_name) is getattr(wrapped, attr_name)
-def test_sync_attrs_match_wrapper(async_file, wrapped) -> None:
+def test_sync_attrs_match_wrapper(
+ async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock
+) -> None:
for attr_name in _FILE_SYNC_ATTRS:
if attr_name in dir(async_file):
continue
@@ -142,7 +150,7 @@ def test_sync_attrs_match_wrapper(async_file, wrapped) -> None:
getattr(wrapped, attr_name)
-def test_async_methods_generated_once(async_file) -> None:
+def test_async_methods_generated_once(async_file: AsyncIOWrapper[mock.Mock]) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name not in dir(async_file):
continue
@@ -150,15 +158,19 @@ def test_async_methods_generated_once(async_file) -> None:
assert getattr(async_file, meth_name) is getattr(async_file, meth_name)
-def test_async_methods_signature(async_file) -> None:
+# I gave up on typing this one
+def test_async_methods_signature(async_file: AsyncIOWrapper[mock.Mock]) -> None:
# use read as a representative of all async methods
assert async_file.read.__name__ == "read"
assert async_file.read.__qualname__ == "AsyncIOWrapper.read"
+ assert async_file.read.__doc__ is not None
assert "io.StringIO.read" in async_file.read.__doc__
-async def test_async_methods_wrap(async_file, wrapped) -> None:
+async def test_async_methods_wrap(
+ async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock
+) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name not in dir(async_file):
continue
@@ -176,7 +188,9 @@ async def test_async_methods_wrap(async_file, wrapped) -> None:
wrapped.reset_mock()
-async def test_async_methods_match_wrapper(async_file, wrapped) -> None:
+async def test_async_methods_match_wrapper(
+ async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock
+) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name in dir(async_file):
continue
@@ -188,7 +202,7 @@ async def test_async_methods_match_wrapper(async_file, wrapped) -> None:
getattr(wrapped, meth_name)
-async def test_open(path) -> None:
+async def test_open(path: pathlib.Path) -> None:
f = await trio.open_file(path, "w")
assert isinstance(f, AsyncIOWrapper)
@@ -196,7 +210,7 @@ async def test_open(path) -> None:
await f.aclose()
-async def test_open_context_manager(path) -> None:
+async def test_open_context_manager(path: pathlib.Path) -> None:
async with await trio.open_file(path, "w") as f:
assert isinstance(f, AsyncIOWrapper)
assert not f.closed
@@ -216,7 +230,7 @@ async def test_async_iter() -> None:
assert result == expected
-async def test_aclose_cancelled(path) -> None:
+async def test_aclose_cancelled(path: pathlib.Path) -> None:
with _core.CancelScope() as cscope:
f = await trio.open_file(path, "w")
cscope.cancel()
From 55ece97dc7cd79360ac4ebab008ea17a48db9fbf Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 18 Oct 2023 15:38:49 +0200
Subject: [PATCH 10/35] type test_subprocess
---
pyproject.toml | 1 -
trio/_tests/test_subprocess.py | 90 +++++++++++++++++++++-------------
2 files changed, 55 insertions(+), 36 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 0e4b7671f9..6d4dc673af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,7 +87,6 @@ module = [
"trio/_core/_tests/test_guest_mode",
"trio/_core/_tests/test_ki",
"trio/_core/_tests/test_multierror",
-"trio/_tests/test_subprocess",
"trio/_tests/test_threads",
]
disallow_any_decorated = false
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index b0c93cdf7a..3da8e48d63 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -8,10 +8,19 @@
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path as SyncPath
-from typing import TYPE_CHECKING, AsyncIterator
+from signal import Signals
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncContextManager,
+ AsyncIterator,
+ Callable,
+ NoReturn,
+)
import pytest
from pytest import MonkeyPatch, WarningsRecorder
+from typing_extensions import TypeAlias
from .. import (
ClosedResourceError,
@@ -24,18 +33,21 @@
sleep,
sleep_forever,
)
+from .._abc import Stream
from .._core._tests.tutil import skip_if_fbsd_pipes_broken, slow
from ..lowlevel import open_process
-from ..testing import assert_no_checkpoints, wait_all_tasks_blocked
+from ..testing import MockClock, assert_no_checkpoints, wait_all_tasks_blocked
+
+if sys.platform == "win32":
+ SignalType: TypeAlias = None
+else:
+ SignalType: TypeAlias = Signals
-if TYPE_CHECKING:
- ...
- from signal import Signals
+SIGKILL: SignalType
+SIGTERM: SignalType
+SIGUSR1: SignalType
posix = os.name == "posix"
-SIGKILL: Signals | None
-SIGTERM: Signals | None
-SIGUSR1: Signals | None
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
from signal import SIGKILL, SIGTERM, SIGUSR1
else:
@@ -64,15 +76,15 @@ def SLEEP(seconds: int) -> list[str]:
return python(f"import time; time.sleep({seconds})")
-def got_signal(proc, sig):
+def got_signal(proc: Process, sig: SignalType) -> bool:
if posix:
return proc.returncode == -sig
else:
return proc.returncode != 0
-@asynccontextmanager
-async def open_process_then_kill(*args, **kwargs):
+@asynccontextmanager # type: ignore[misc] # Any in decorator
+async def open_process_then_kill(*args: Any, **kwargs: Any) -> AsyncIterator[Process]:
proc = await open_process(*args, **kwargs)
try:
yield proc
@@ -81,16 +93,11 @@ async def open_process_then_kill(*args, **kwargs):
await proc.wait()
-# not entirely sure about this annotation
-@asynccontextmanager
-async def run_process_in_nursery(
- *args, **kwargs
-) -> AsyncIterator[subprocess.CompletedProcess[bytes]]:
+@asynccontextmanager # type: ignore[misc] # Any in decorator
+async def run_process_in_nursery(*args: Any, **kwargs: Any) -> AsyncIterator[Process]:
async with _core.open_nursery() as nursery:
kwargs.setdefault("check", False)
- proc: subprocess.CompletedProcess[bytes] = await nursery.start(
- partial(run_process, *args, **kwargs)
- )
+ proc: Process = await nursery.start(partial(run_process, *args, **kwargs))
yield proc
nursery.cancel_scope.cancel()
@@ -101,9 +108,11 @@ async def run_process_in_nursery(
ids=["open_process", "run_process in nursery"],
)
+BackgroundProcessType: TypeAlias = Callable[..., AsyncContextManager[Process]]
+
@background_process_param
-async def test_basic(background_process) -> None:
+async def test_basic(background_process: BackgroundProcessType) -> None:
async with background_process(EXIT_TRUE) as proc:
await proc.wait()
assert isinstance(proc, Process)
@@ -120,7 +129,9 @@ async def test_basic(background_process) -> None:
@background_process_param
-async def test_auto_update_returncode(background_process) -> None:
+async def test_auto_update_returncode(
+ background_process: BackgroundProcessType,
+) -> None:
async with background_process(SLEEP(9999)) as p:
assert p.returncode is None
assert "running" in repr(p)
@@ -133,7 +144,7 @@ async def test_auto_update_returncode(background_process) -> None:
@background_process_param
-async def test_multi_wait(background_process) -> None:
+async def test_multi_wait(background_process: BackgroundProcessType) -> None:
async with background_process(SLEEP(10)) as proc:
# Check that wait (including multi-wait) tolerates being cancelled
async with _core.open_nursery() as nursery:
@@ -189,7 +200,7 @@ async def test_kill_when_context_cancelled(recwarn: WarningsRecorder) -> None:
@background_process_param
-async def test_pipes(background_process) -> None:
+async def test_pipes(background_process: BackgroundProcessType) -> None:
async with background_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
@@ -199,10 +210,12 @@ async def test_pipes(background_process) -> None:
msg = b"the quick brown fox jumps over the lazy dog"
async def feed_input() -> None:
+ assert proc.stdin is not None
await proc.stdin.send_all(msg)
await proc.stdin.aclose()
- async def check_output(stream, expected) -> None:
+ async def check_output(stream: Stream, expected: bytes) -> None:
+ assert type(stream) is None
seen = bytearray()
async for chunk in stream:
seen += chunk
@@ -220,7 +233,7 @@ async def check_output(stream, expected) -> None:
@background_process_param
-async def test_interactive(background_process) -> None:
+async def test_interactive(background_process: BackgroundProcessType) -> None:
# Test some back-and-forth with a subprocess. This one works like so:
# in: 32\n
# out: 0000...0000\n (32 zeroes)
@@ -249,10 +262,10 @@ async def test_interactive(background_process) -> None:
) as proc:
newline = b"\n" if posix else b"\r\n"
- async def expect(idx: int, request) -> None:
+ async def expect(idx: int, request: int) -> None:
async with _core.open_nursery() as nursery:
- async def drain_one(stream, count: int, digit) -> None:
+ async def drain_one(stream: Stream, count: int, digit: int) -> None:
while count > 0:
result = await stream.receive_some(count)
assert result == (f"{digit}".encode() * len(result))
@@ -263,6 +276,9 @@ async def drain_one(stream, count: int, digit) -> None:
nursery.start_soon(drain_one, proc.stdout, request, idx * 2)
nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1)
+ assert proc.stdin is not None
+ assert proc.stdout is not None
+ assert proc.stderr is not None
with fail_after(5):
await proc.stdin.send_all(b"12")
await sleep(0.1)
@@ -358,13 +374,14 @@ async def test_run_with_broken_pipe() -> None:
@background_process_param
-async def test_stderr_stdout(background_process) -> None:
+async def test_stderr_stdout(background_process: BackgroundProcessType) -> None:
async with background_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as proc:
+ assert proc.stdio is not None
assert proc.stdout is not None
assert proc.stderr is None
await proc.stdio.send_all(b"1234")
@@ -438,14 +455,17 @@ async def test_errors() -> None:
@background_process_param
-async def test_signals(background_process) -> None:
- async def test_one_signal(send_it, signum) -> None:
+async def test_signals(background_process: BackgroundProcessType) -> None:
+ async def test_one_signal(
+ send_it: Callable[[Process], None], signum: signal.Signals | None
+ ) -> None:
with move_on_after(1.0) as scope:
async with background_process(SLEEP(3600)) as proc:
send_it(proc)
await proc.wait()
assert not scope.cancelled_caught
if posix:
+ assert signum is not None
assert proc.returncode == -signum
else:
assert proc.returncode != 0
@@ -465,7 +485,7 @@ async def test_one_signal(send_it, signum) -> None:
@pytest.mark.skipif(not posix, reason="POSIX specific")
@background_process_param
-async def test_wait_reapable_fails(background_process) -> None:
+async def test_wait_reapable_fails(background_process: BackgroundProcessType) -> None:
old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
try:
# With SIGCHLD disabled, the wait() syscall will wait for the
@@ -496,7 +516,7 @@ def test_waitid_eintr() -> None:
got_alarm = False
sleeper = subprocess.Popen(["sleep", "3600"])
- def on_alarm(sig, frame) -> None:
+ def on_alarm(sig: object, frame: object) -> None:
nonlocal got_alarm
got_alarm = True
sleeper.kill()
@@ -518,7 +538,7 @@ def on_alarm(sig, frame) -> None:
async def test_custom_deliver_cancel() -> None:
custom_deliver_cancel_called = False
- async def custom_deliver_cancel(proc) -> None:
+ async def custom_deliver_cancel(proc: Process) -> None:
nonlocal custom_deliver_cancel_called
custom_deliver_cancel_called = True
proc.terminate()
@@ -542,7 +562,7 @@ async def custom_deliver_cancel(proc) -> None:
async def test_warn_on_failed_cancel_terminate(monkeypatch: MonkeyPatch) -> None:
original_terminate = Process.terminate
- def broken_terminate(self):
+ def broken_terminate(self: Process) -> NoReturn:
original_terminate(self)
raise OSError("whoops")
@@ -557,7 +577,7 @@ def broken_terminate(self):
@pytest.mark.skipif(not posix, reason="posix only")
async def test_warn_on_cancel_SIGKILL_escalation(
- autojump_clock, monkeypatch: MonkeyPatch
+ autojump_clock: MockClock, monkeypatch: MonkeyPatch
) -> None:
monkeypatch.setattr(Process, "terminate", lambda *args: None)
From 6d19cd2a9ed04a346bc36d12301b13dacd08d1e1 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 18 Oct 2023 15:59:45 +0200
Subject: [PATCH 11/35] type test_threads
---
pyproject.toml | 1 -
trio/_tests/test_threads.py | 89 +++++++++++++++++++++++--------------
2 files changed, 56 insertions(+), 34 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 6d4dc673af..5e3f301d63 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,7 +87,6 @@ module = [
"trio/_core/_tests/test_guest_mode",
"trio/_core/_tests/test_ki",
"trio/_core/_tests/test_multierror",
-"trio/_tests/test_threads",
]
disallow_any_decorated = false
disallow_incomplete_defs = false
diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py
index 80b870445d..da75589840 100644
--- a/trio/_tests/test_threads.py
+++ b/trio/_tests/test_threads.py
@@ -8,7 +8,19 @@
import time
import weakref
from functools import partial
-from typing import TYPE_CHECKING, Callable, Optional
+from typing import (
+ TYPE_CHECKING,
+ AsyncGenerator,
+ Awaitable,
+ Callable,
+ List,
+ NoReturn,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
import pytest
import sniffio
@@ -16,7 +28,7 @@
from trio._core import TrioToken, current_trio_token
-from .. import CapacityLimiter, Event, _core, sleep
+from .. import CancelScope, CapacityLimiter, Event, _core, sleep
from .._core._tests.test_ki import ki_self
from .._core._tests.tutil import buggy_pypy_asyncgens
from .._threads import (
@@ -25,14 +37,23 @@
from_thread_run_sync,
to_thread_run_sync,
)
+from ..lowlevel import Task
from ..testing import wait_all_tasks_blocked
+RecordType = List[Tuple[str, Union[threading.Thread, Type[BaseException]]]]
+T = TypeVar("T")
+
async def test_do_in_trio_thread() -> None:
trio_thread = threading.current_thread()
- async def check_case(do_in_trio_thread, fn, expected, trio_token=None) -> None:
- record: list[tuple[str, threading.Thread | type[BaseException]]] = []
+ async def check_case(
+ do_in_trio_thread: Callable[..., threading.Thread],
+ fn: Callable[..., T | Awaitable[T]],
+ expected: tuple[str, T],
+ trio_token: TrioToken | None = None,
+ ) -> None:
+ record: RecordType = []
def threadfn() -> None:
try:
@@ -52,21 +73,21 @@ def threadfn() -> None:
token = _core.current_trio_token()
- def f1(record) -> int:
+ def f1(record: RecordType) -> int:
assert not _core.currently_ki_protected()
record.append(("f", threading.current_thread()))
return 2
await check_case(from_thread_run_sync, f1, ("got", 2), trio_token=token)
- def f2(record):
+ def f2(record: RecordType) -> NoReturn:
assert not _core.currently_ki_protected()
record.append(("f", threading.current_thread()))
raise ValueError
await check_case(from_thread_run_sync, f2, ("error", ValueError), trio_token=token)
- async def f3(record) -> int:
+ async def f3(record: RecordType) -> int:
assert not _core.currently_ki_protected()
await _core.checkpoint()
record.append(("f", threading.current_thread()))
@@ -74,7 +95,7 @@ async def f3(record) -> int:
await check_case(from_thread_run, f3, ("got", 3), trio_token=token)
- async def f4(record):
+ async def f4(record: RecordType) -> NoReturn:
assert not _core.currently_ki_protected()
await _core.checkpoint()
record.append(("f", threading.current_thread()))
@@ -151,13 +172,13 @@ async def trio_fn() -> None:
ev.set()
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
- def thread_fn(token) -> None:
+ def thread_fn(token: TrioToken) -> None:
try:
from_thread_run(trio_fn, trio_token=token)
except _core.Cancelled:
record.append("cancelled")
- async def main():
+ async def main() -> threading.Thread:
token = _core.current_trio_token()
thread = threading.Thread(target=thread_fn, args=(token,))
thread.start()
@@ -284,14 +305,14 @@ async def test_has_pthread_setname_np() -> None:
async def test_run_in_worker_thread() -> None:
trio_thread = threading.current_thread()
- def f(x):
+ def f(x: T) -> tuple[T, threading.Thread]:
return (x, threading.current_thread())
x, child_thread = await to_thread_run_sync(f, 1)
assert x == 1
assert child_thread != trio_thread
- def g():
+ def g() -> NoReturn:
raise ValueError(threading.current_thread())
with pytest.raises(ValueError) as excinfo:
@@ -303,13 +324,13 @@ def g():
async def test_run_in_worker_thread_cancellation() -> None:
register: list[str | None] = [None]
- def f(q) -> None:
+ def f(q: stdlib_queue.Queue[str]) -> None:
# Make the thread block for a controlled amount of time
register[0] = "blocking"
q.get()
register[0] = "finished"
- async def child(q, cancellable):
+ async def child(q: stdlib_queue.Queue[None], cancellable: bool) -> None:
record.append("start")
try:
return await to_thread_run_sync(f, q, cancellable=cancellable)
@@ -400,7 +421,9 @@ async def child() -> None:
@pytest.mark.parametrize("MAX", [3, 5, 10])
@pytest.mark.parametrize("cancel", [False, True])
@pytest.mark.parametrize("use_default_limiter", [False, True])
-async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter) -> None:
+async def test_run_in_worker_thread_limiter(
+ MAX: int, cancel: bool, use_default_limiter: bool
+) -> None:
# This test is a bit tricky. The goal is to make sure that if we set
# limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever
# running at a time, even if there are more concurrent calls to
@@ -444,7 +467,7 @@ class state:
token = _core.current_trio_token()
- def thread_fn(cancel_scope) -> None:
+ def thread_fn(cancel_scope: CancelScope) -> None:
print("thread_fn start")
from_thread_run_sync(cancel_scope.cancel, trio_token=token)
with lock:
@@ -460,7 +483,7 @@ def thread_fn(cancel_scope) -> None:
state.running -= 1
print("thread_fn exiting")
- async def run_thread(event) -> None:
+ async def run_thread(event: Event) -> None:
with _core.CancelScope() as cancel_scope:
await to_thread_run_sync(
thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel
@@ -515,11 +538,11 @@ async def test_run_in_worker_thread_custom_limiter() -> None:
record = []
class CustomLimiter:
- async def acquire_on_behalf_of(self, borrower) -> None:
+ async def acquire_on_behalf_of(self, borrower: Task) -> None:
record.append("acquire")
self._borrower = borrower
- def release_on_behalf_of(self, borrower) -> None:
+ def release_on_behalf_of(self, borrower: Task) -> None:
record.append("release")
assert borrower == self._borrower
@@ -533,10 +556,10 @@ async def test_run_in_worker_thread_limiter_error() -> None:
record = []
class BadCapacityLimiter:
- async def acquire_on_behalf_of(self, borrower) -> None:
+ async def acquire_on_behalf_of(self, borrower: Task) -> None:
record.append("acquire")
- def release_on_behalf_of(self, borrower):
+ def release_on_behalf_of(self, borrower: Task) -> NoReturn:
record.append("release")
raise ValueError
@@ -559,7 +582,7 @@ def release_on_behalf_of(self, borrower):
async def test_run_in_worker_thread_fail_to_spawn(monkeypatch: MonkeyPatch) -> None:
# Test the unlikely but possible case where trying to spawn a thread fails
- def bad_start(self, *args):
+ def bad_start(self: object, *args: object) -> NoReturn:
raise RuntimeError("the engines canna take it captain")
monkeypatch.setattr(_core._thread_cache.ThreadCache, "start_thread_soon", bad_start)
@@ -578,7 +601,7 @@ def bad_start(self, *args):
async def test_trio_to_thread_run_sync_token() -> None:
# Test that to_thread_run_sync automatically injects the current trio token
# into a spawned thread
- def thread_fn():
+ def thread_fn() -> TrioToken:
callee_token = from_thread_run_sync(_core.current_trio_token)
return callee_token
@@ -605,7 +628,7 @@ async def test_trio_to_thread_run_sync_contextvars() -> None:
trio_thread = threading.current_thread()
trio_test_contextvar.set("main")
- def f():
+ def f() -> tuple[str, threading.Thread]:
value = trio_test_contextvar.get()
with pytest.raises(sniffio.AsyncLibraryNotFoundError):
sniffio.current_async_library()
@@ -615,7 +638,7 @@ def f():
assert value == "main"
assert child_thread != trio_thread
- def g():
+ def g() -> tuple[str, str, threading.Thread]:
parent_value = trio_test_contextvar.get()
trio_test_contextvar.set("worker")
inner_value = trio_test_contextvar.get()
@@ -641,7 +664,7 @@ def g():
async def test_trio_from_thread_run_sync() -> None:
# Test that to_thread_run_sync correctly "hands off" the trio token to
# trio.from_thread.run_sync()
- def thread_fn_1():
+ def thread_fn_1() -> float:
trio_time = from_thread_run_sync(_core.current_time)
return trio_time
@@ -686,7 +709,7 @@ def sync_fn() -> None: # pragma: no cover
async def test_trio_from_thread_token() -> None:
# Test that to_thread_run_sync and spawned trio.from_thread.run_sync()
# share the same Trio token
- def thread_fn():
+ def thread_fn() -> TrioToken:
callee_token = from_thread_run_sync(_core.current_trio_token)
return callee_token
@@ -698,7 +721,7 @@ def thread_fn():
async def test_trio_from_thread_token_kwarg() -> None:
# Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can
# use an explicitly defined token
- def thread_fn(token):
+ def thread_fn(token: TrioToken) -> TrioToken:
callee_token = from_thread_run_sync(_core.current_trio_token, trio_token=token)
return callee_token
@@ -718,14 +741,14 @@ async def test_from_thread_no_token() -> None:
async def test_trio_from_thread_run_sync_contextvars() -> None:
trio_test_contextvar.set("main")
- def thread_fn():
+ def thread_fn() -> tuple[str, str, str, str, str]:
thread_parent_value = trio_test_contextvar.get()
trio_test_contextvar.set("worker")
thread_current_value = trio_test_contextvar.get()
with pytest.raises(sniffio.AsyncLibraryNotFoundError):
sniffio.current_async_library()
- def back_in_main():
+ def back_in_main() -> tuple[str, str]:
back_parent_value = trio_test_contextvar.get()
trio_test_contextvar.set("back_in_main")
back_current_value = trio_test_contextvar.get()
@@ -761,14 +784,14 @@ def back_in_main():
async def test_trio_from_thread_run_contextvars() -> None:
trio_test_contextvar.set("main")
- def thread_fn():
+ def thread_fn() -> tuple[str, str, str, str, str]:
thread_parent_value = trio_test_contextvar.get()
trio_test_contextvar.set("worker")
thread_current_value = trio_test_contextvar.get()
with pytest.raises(sniffio.AsyncLibraryNotFoundError):
sniffio.current_async_library()
- async def async_back_in_main():
+ async def async_back_in_main() -> tuple[str, str]:
back_parent_value = trio_test_contextvar.get()
trio_test_contextvar.set("back_in_main")
back_current_value = trio_test_contextvar.get()
@@ -823,7 +846,7 @@ def test_from_thread_run_during_shutdown() -> None:
save = []
record = []
- async def agen():
+ async def agen() -> AsyncGenerator[None, None]:
try:
yield
finally:
From e8f01124e2d9326b33ba243b5c4bee1f5f290019 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 18 Oct 2023 16:01:54 +0200
Subject: [PATCH 12/35] finish typing test_guest_mode
---
pyproject.toml | 1 -
trio/_core/_tests/test_guest_mode.py | 12 ++++++++++--
2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 5e3f301d63..9edd6d808b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -84,7 +84,6 @@ disallow_untyped_calls = false
[[tool.mypy.overrides]]
module = [
# tests
-"trio/_core/_tests/test_guest_mode",
"trio/_core/_tests/test_ki",
"trio/_core/_tests/test_multierror",
]
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index 286e8657a1..3358f63a46 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -12,7 +12,15 @@
import warnings
from functools import partial
from math import inf
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, NoReturn, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncGenerator,
+ Awaitable,
+ Callable,
+ NoReturn,
+ TypeVar,
+)
import pytest
from outcome import Outcome
@@ -613,7 +621,7 @@ def test_guest_mode_asyncgens() -> None:
record = set()
- async def agen(label: str): # TODO: some asyncgenerator thing
+ async def agen(label: str) -> AsyncGenerator[int, None]:
assert sniffio.current_async_library() == label
try:
yield 1
From 7ef03d9d1edf35b5242e007a042272edec86e3d4 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 18 Oct 2023 16:07:02 +0200
Subject: [PATCH 13/35] type test_ki
---
pyproject.toml | 2 +-
trio/_core/_tests/test_ki.py | 26 +++++++++++++-------------
2 files changed, 14 insertions(+), 14 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 9edd6d808b..949fde14fd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -84,7 +84,7 @@ disallow_untyped_calls = false
[[tool.mypy.overrides]]
module = [
# tests
-"trio/_core/_tests/test_ki",
+#"trio/_core/_tests/test_ki",
"trio/_core/_tests/test_multierror",
]
disallow_any_decorated = false
diff --git a/trio/_core/_tests/test_ki.py b/trio/_core/_tests/test_ki.py
index 60a6b17336..bc4a0192af 100644
--- a/trio/_core/_tests/test_ki.py
+++ b/trio/_core/_tests/test_ki.py
@@ -4,7 +4,7 @@
import inspect
import signal
import threading
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator
import outcome
import pytest
@@ -77,7 +77,7 @@ async def aunprotected() -> None:
nursery.start_soon(aunprotected)
@_core.enable_ki_protection
- def gen_protected():
+ def gen_protected() -> Iterator[None]:
assert _core.currently_ki_protected()
yield
@@ -85,7 +85,7 @@ def gen_protected():
pass
@_core.disable_ki_protection
- def gen_unprotected():
+ def gen_unprotected() -> Iterator[None]:
assert not _core.currently_ki_protected()
yield
@@ -111,7 +111,7 @@ async def protected() -> None:
async def unprotected() -> None:
await child(False)
- async def child(expected) -> None:
+ async def child(expected: bool) -> None:
import traceback
traceback.print_stack()
@@ -126,10 +126,10 @@ async def child(expected) -> None:
# This also used to be broken due to
# https://bugs.python.org/issue29590
-async def test_generator_based_context_manager_throw():
+async def test_generator_based_context_manager_throw() -> None:
@contextlib.contextmanager
@_core.enable_ki_protection
- def protected_manager():
+ def protected_manager() -> Iterator[None]:
assert _core.currently_ki_protected()
try:
yield
@@ -193,7 +193,7 @@ async def agen_unprotected2() -> None:
async def test_native_agen_protection() -> None:
# Native async generators
@_core.enable_ki_protection
- async def agen_protected():
+ async def agen_protected() -> AsyncIterator[None]:
assert _core.currently_ki_protected()
try:
yield
@@ -201,7 +201,7 @@ async def agen_protected():
assert _core.currently_ki_protected()
@_core.disable_ki_protection
- async def agen_unprotected():
+ async def agen_unprotected() -> AsyncIterator[None]:
assert not _core.currently_ki_protected()
try:
yield
@@ -212,7 +212,7 @@ async def agen_unprotected():
await _check_agen(agen_unprotected)
-async def _check_agen(agen_fn):
+async def _check_agen(agen_fn: Callable[[], AsyncIterator[None]]) -> None:
async for _ in agen_fn():
assert not _core.currently_ki_protected()
@@ -235,7 +235,7 @@ def test_ki_disabled_out_of_context() -> None:
def test_ki_disabled_in_del() -> None:
- def nestedfunction():
+ def nestedfunction() -> bool:
return _core.currently_ki_protected()
def __del__() -> None:
@@ -254,14 +254,14 @@ def outerfunction() -> None:
def test_ki_protection_works() -> None:
- async def sleeper(name: str, record) -> None:
+ async def sleeper(name: str, record: set[str]) -> None:
try:
while True:
await _core.checkpoint()
except _core.Cancelled:
record.add(name + " ok")
- async def raiser(name: str, record):
+ async def raiser(name: str, record: set[str]) -> None:
try:
# os.kill runs signal handlers before returning, so we don't need
# to worry that the handler will be delayed
@@ -475,7 +475,7 @@ def test_ki_is_good_neighbor() -> None:
try:
orig = signal.getsignal(signal.SIGINT)
- def my_handler(signum, frame) -> None: # pragma: no cover
+ def my_handler(signum: object, frame: object) -> None: # pragma: no cover
pass
async def main() -> None:
From c7ff2e14ebd69b0b8d951171da8a9b5a6aea9e50 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 18 Oct 2023 16:24:15 +0200
Subject: [PATCH 14/35] type test_multierror
---
pyproject.toml | 15 ++------
trio/_core/_tests/test_multierror.py | 54 +++++++++++++++-------------
trio/_tests/test_subprocess.py | 1 -
3 files changed, 32 insertions(+), 38 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 949fde14fd..75849ec9bd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,26 +70,15 @@ warn_return_any = true
# Avoid subtle backsliding
disallow_any_decorated = true
disallow_any_generics = true
-disallow_any_unimported = false # Enable once Outcome has stubs.
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
+check_untyped_defs = true
# Enable once other problems are dealt with
-check_untyped_defs = true
disallow_untyped_calls = false
-
-# files not yet fully typed
-[[tool.mypy.overrides]]
-module = [
-# tests
-#"trio/_core/_tests/test_ki",
-"trio/_core/_tests/test_multierror",
-]
-disallow_any_decorated = false
-disallow_incomplete_defs = false
-disallow_untyped_defs = false
+disallow_any_unimported = false # Enable once Outcome has stubs.
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py
index 60ac41ae14..f2682cdf9f 100644
--- a/trio/_core/_tests/test_multierror.py
+++ b/trio/_core/_tests/test_multierror.py
@@ -9,7 +9,8 @@
import warnings
from pathlib import Path
from traceback import extract_tb, print_exception
-from typing import List
+from types import TracebackType
+from typing import Callable, List, NoReturn
import pytest
@@ -35,42 +36,43 @@ def __eq__(self, other: object) -> bool:
return self.code == other.code
-async def raise_nothashable(code):
+async def raise_nothashable(code: int) -> NoReturn:
raise NotHashableException(code)
-def raiser1() -> None:
+def raiser1() -> NoReturn:
raiser1_2()
-def raiser1_2() -> None:
+def raiser1_2() -> NoReturn:
raiser1_3()
-def raiser1_3():
+def raiser1_3() -> NoReturn:
raise ValueError("raiser1_string")
-def raiser2() -> None:
+def raiser2() -> NoReturn:
raiser2_2()
-def raiser2_2():
+def raiser2_2() -> NoReturn:
raise KeyError("raiser2_string")
-def raiser3():
+def raiser3() -> NoReturn:
raise NameError
-def get_exc(raiser):
+def get_exc(raiser: Callable[[], NoReturn]) -> BaseException:
try:
raiser()
except Exception as exc:
return exc
+ raise AssertionError("raiser should always raise")
-def get_tb(raiser):
+def get_tb(raiser: Callable[[], NoReturn]) -> TracebackType | None:
return get_exc(raiser).__traceback__
@@ -141,7 +143,7 @@ async def test_MultiErrorNotHashable() -> None:
def test_MultiError_filter_NotHashable() -> None:
excs = MultiError([NotHashableException(42), ValueError()])
- def handle_ValueError(exc):
+ def handle_ValueError(exc: BaseException) -> BaseException | None:
if isinstance(exc, ValueError):
return None
else:
@@ -153,7 +155,7 @@ def handle_ValueError(exc):
assert isinstance(filtered_excs, NotHashableException)
-def make_tree():
+def make_tree() -> MultiError:
# Returns an object like:
# MultiError([
# MultiError([
@@ -174,7 +176,9 @@ def make_tree():
return MultiError([m12, exc3])
-def assert_tree_eq(m1, m2) -> None:
+def assert_tree_eq(
+ m1: BaseException | MultiError | None, m2: BaseException | MultiError | None
+) -> None:
if m1 is None or m2 is None:
assert m1 is m2
return
@@ -183,13 +187,14 @@ def assert_tree_eq(m1, m2) -> None:
assert_tree_eq(m1.__cause__, m2.__cause__)
assert_tree_eq(m1.__context__, m2.__context__)
if isinstance(m1, MultiError):
+ assert isinstance(m2, MultiError)
assert len(m1.exceptions) == len(m2.exceptions)
for e1, e2 in zip(m1.exceptions, m2.exceptions):
assert_tree_eq(e1, e2)
-def test_MultiError_filter():
- def null_handler(exc):
+def test_MultiError_filter() -> None:
+ def null_handler(exc: BaseException) -> BaseException:
return exc
m = make_tree()
@@ -209,7 +214,7 @@ def null_handler(exc):
assert MultiError.filter(null_handler, m) is m
assert_tree_eq(m, make_tree())
- def simple_filter(exc):
+ def simple_filter(exc: BaseException) -> BaseException | None:
if isinstance(exc, ValueError):
return None
if isinstance(exc, KeyError):
@@ -233,6 +238,7 @@ def simple_filter(exc):
# traceback on its parent MultiError
orig = make_tree()
# make sure we have the right path
+ assert isinstance(orig.exceptions[0], MultiError)
assert isinstance(orig.exceptions[0].exceptions[1], KeyError)
# get original traceback summary
orig_extracted = (
@@ -241,7 +247,7 @@ def simple_filter(exc):
+ extract_tb(orig.exceptions[0].exceptions[1].__traceback__)
)
- def p(exc) -> None:
+ def p(exc: BaseException) -> None:
print_exception(type(exc), exc, exc.__traceback__)
p(orig)
@@ -254,7 +260,7 @@ def p(exc) -> None:
assert orig_extracted == new_extracted
# check preserving partial tree
- def filter_NameError(exc):
+ def filter_NameError(exc: BaseException) -> BaseException | None:
if isinstance(exc, NameError):
return None
return exc
@@ -266,17 +272,17 @@ def filter_NameError(exc):
assert new_m is m.exceptions[0]
# check fully handling everything
- def filter_all(exc):
+ def filter_all(exc: BaseException) -> None:
return None
with pytest.warns(TrioDeprecationWarning):
assert MultiError.filter(filter_all, make_tree()) is None
-def test_MultiError_catch():
+def test_MultiError_catch() -> None:
# No exception to catch
- def noop(_) -> None:
+ def noop(_: object) -> None:
pass # pragma: no cover
with pytest.warns(TrioDeprecationWarning), MultiError.catch(noop):
@@ -363,16 +369,16 @@ def catch_RuntimeError(exc):
@pytest.mark.skipif(
sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC"
)
-def test_MultiError_catch_doesnt_create_cyclic_garbage():
+def test_MultiError_catch_doesnt_create_cyclic_garbage() -> None:
# https://github.com/python-trio/trio/pull/2063
gc.collect()
old_flags = gc.get_debug()
- def make_multi():
+ def make_multi() -> NoReturn:
# make_tree creates cycles itself, so a simple
raise MultiError([get_exc(raiser1), get_exc(raiser2)])
- def simple_filter(exc):
+ def simple_filter(exc: BaseException) -> Exception | RuntimeError:
if isinstance(exc, ValueError):
return Exception()
if isinstance(exc, KeyError):
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index 3da8e48d63..e4713b3950 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -215,7 +215,6 @@ async def feed_input() -> None:
await proc.stdin.aclose()
async def check_output(stream: Stream, expected: bytes) -> None:
- assert type(stream) is None
seen = bytearray()
async for chunk in stream:
seen += chunk
From e873df1eb3dc8a670852d9409aea16abc5606c11 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 18 Oct 2023 17:35:30 +0200
Subject: [PATCH 15/35] fix some type errors I missed outside of _tests/ -
testing/_fake_net.py and missing type parameters for SSLStream
---
trio/_highlevel_ssl_helpers.py | 17 ++-
trio/_ssl.py | 6 +-
trio/_tests/test_highlevel_ssl_helpers.py | 4 +-
trio/_tests/test_ssl.py | 2 +-
trio/testing/_fake_net.py | 133 +++++++++++++++-------
5 files changed, 112 insertions(+), 50 deletions(-)
diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py
index 1647f373c2..c03919d6c0 100644
--- a/trio/_highlevel_ssl_helpers.py
+++ b/trio/_highlevel_ssl_helpers.py
@@ -2,11 +2,16 @@
import ssl
from collections.abc import Awaitable, Callable
-from typing import NoReturn
+from typing import NoReturn, TypeVar
import trio
from ._highlevel_open_tcp_stream import DEFAULT_DELAY
+from ._highlevel_socket import SocketStream
+from .abc import Stream
+
+T = TypeVar("T")
+T_Stream = TypeVar("T_Stream", bound=Stream)
# It might have been nice to take a ssl_protocols= argument here to set up
@@ -25,7 +30,7 @@ async def open_ssl_over_tcp_stream(
https_compatible: bool = False,
ssl_context: ssl.SSLContext | None = None,
happy_eyeballs_delay: float | None = DEFAULT_DELAY,
-) -> trio.SSLStream:
+) -> trio.SSLStream[SocketStream]:
"""Make a TLS-encrypted Connection to the given host and port over TCP.
This is a convenience wrapper that calls :func:`open_tcp_stream` and
@@ -73,7 +78,7 @@ async def open_ssl_over_tcp_listeners(
host: str | bytes | None = None,
https_compatible: bool = False,
backlog: int | float | None = None,
-) -> list[trio.SSLListener]:
+) -> list[trio.SSLListener[SocketStream]]:
"""Start listening for SSL/TLS-encrypted TCP connections to the given port.
Args:
@@ -95,7 +100,7 @@ async def open_ssl_over_tcp_listeners(
async def serve_ssl_over_tcp(
- handler: Callable[[trio.SSLStream], Awaitable[object]],
+ handler: Callable[[trio.SSLStream[T_Stream]], Awaitable[object]],
port: int,
ssl_context: ssl.SSLContext,
*,
@@ -103,7 +108,9 @@ async def serve_ssl_over_tcp(
https_compatible: bool = False,
backlog: int | float | None = None,
handler_nursery: trio.Nursery | None = None,
- task_status: trio.TaskStatus[list[trio.SSLListener]] = trio.TASK_STATUS_IGNORED,
+ task_status: trio.TaskStatus[
+ list[trio.SSLListener[T_Stream]]
+ ] = trio.TASK_STATUS_IGNORED,
) -> NoReturn:
"""Listen for incoming TCP connections, and for each one start a task
running ``handler(stream)``.
diff --git a/trio/_ssl.py b/trio/_ssl.py
index 073d7e3812..77d3b80140 100644
--- a/trio/_ssl.py
+++ b/trio/_ssl.py
@@ -893,7 +893,7 @@ async def wait_send_all_might_not_block(self) -> None:
@final
-class SSLListener(Listener[SSLStream]):
+class SSLListener(Listener[SSLStream[T_Stream]]):
"""A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers.
:class:`SSLListener` wraps around another Listener, and converts
@@ -917,7 +917,7 @@ class SSLListener(Listener[SSLStream]):
def __init__(
self,
- transport_listener: Listener[Stream],
+ transport_listener: Listener[T_Stream],
ssl_context: _stdlib_ssl.SSLContext,
*,
https_compatible: bool = False,
@@ -926,7 +926,7 @@ def __init__(
self._ssl_context = ssl_context
self._https_compatible = https_compatible
- async def accept(self) -> SSLStream:
+ async def accept(self) -> SSLStream[T_Stream]:
"""Accept the next connection and wrap it in an :class:`SSLStream`.
See :meth:`trio.abc.Listener.accept` for details.
diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py
index 5a06279204..c1b0febbd5 100644
--- a/trio/_tests/test_highlevel_ssl_helpers.py
+++ b/trio/_tests/test_highlevel_ssl_helpers.py
@@ -74,13 +74,13 @@ async def test_open_ssl_over_tcp_stream_and_everything_else(
async with trio.open_nursery() as nursery:
# TODO: the types are *very* funky here, this seems like an error in some signature
# unless this is doing stuff we don't want/expect end users to do
- res: list[SSLListener] = await nursery.start(
+ res: list[SSLListener[SocketListener]] = await nursery.start(
partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1")
)
(listener,) = res
async with listener:
# listener.transport_listener is of type Listener[Stream]
- tp_listener: SocketListener = listener.transport_listener # type: ignore[assignment]
+ tp_listener: SocketListener = listener.transport_listener
sockaddr = tp_listener.socket.getsockname()
hostname_resolver = FakeHostnameResolver(sockaddr)
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index 5dece46e54..58e069f239 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -1295,7 +1295,7 @@ async def test_getpeercert(client_ctx: SSLContext) -> None:
async def test_SSLListener(client_ctx: SSLContext) -> None:
async def setup(
**kwargs: Any,
- ) -> tuple[tsocket.SocketType, SSLListener, SSLStream[SocketStream]]:
+ ) -> tuple[tsocket.SocketType, SSLListener[SocketStream], SSLStream[SocketStream]]:
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(1)
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
index 74ce32d37f..2b1d2c8b34 100644
--- a/trio/testing/_fake_net.py
+++ b/trio/testing/_fake_net.py
@@ -12,9 +12,19 @@
import errno
import ipaddress
import os
-from typing import TYPE_CHECKING, Optional, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Iterable,
+ NoReturn,
+ Optional,
+ TypeVar,
+ Union,
+ overload,
+)
import attr
+from typing_extensions import Buffer, Self
import trio
from trio._util import NoPublicConstructor, final
@@ -54,11 +64,11 @@ def _localhost_ip_for(family: int) -> IPAddress:
assert False
-def _fake_err(code):
+def _fake_err(code: int) -> NoReturn:
raise OSError(code, os.strerror(code))
-def _scatter(data, buffers):
+def _scatter(data: bytes, buffers: Iterable[bytes]) -> int:
written = 0
for buf in buffers:
next_piece = data[written : written + len(buf)]
@@ -70,19 +80,27 @@ def _scatter(data, buffers):
return written
+T_UDPEndpoint = TypeVar("T_UDPEndpoint", bound="UDPEndpoint")
+
+
@attr.frozen
class UDPEndpoint:
ip: IPAddress
port: int
- def as_python_sockaddr(self):
- sockaddr = (self.ip.compressed, self.port)
+ def as_python_sockaddr(self) -> tuple[str, int] | tuple[str, int, int, int]:
+ sockaddr: tuple[str, int] | tuple[str, int, int, int] = (
+ self.ip.compressed,
+ self.port,
+ )
if isinstance(self.ip, ipaddress.IPv6Address):
- sockaddr += (0, 0)
+ sockaddr += (0, 0) # type: ignore[assignment]
return sockaddr
@classmethod
- def from_python_sockaddr(cls, sockaddr):
+ def from_python_sockaddr(
+ cls: type[T_UDPEndpoint], sockaddr: tuple[str, int] | tuple[str, int, int, int]
+ ) -> T_UDPEndpoint:
ip, port = sockaddr[:2]
return cls(ip=ipaddress.ip_address(ip), port=port)
@@ -98,7 +116,7 @@ class UDPPacket:
destination: UDPEndpoint
payload: bytes = attr.ib(repr=lambda p: p.hex())
- def reply(self, payload):
+ def reply(self, payload: bytes) -> UDPPacket:
return UDPPacket(
source=self.destination, destination=self.source, payload=payload
)
@@ -162,13 +180,13 @@ def enable(self) -> None:
trio.socket.set_custom_socket_factory(FakeSocketFactory(self))
trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self))
- def send_packet(self, packet) -> None:
+ def send_packet(self, packet: UDPPacket) -> None:
if self.route_packet is None:
self.deliver_packet(packet)
else:
self.route_packet(packet)
- def deliver_packet(self, packet) -> None:
+ def deliver_packet(self, packet: UDPPacket) -> None:
binding = UDPBinding(local=packet.destination)
if binding in self._bound:
self._bound[binding]._deliver_packet(packet)
@@ -219,11 +237,11 @@ def family(self) -> AddressFamily:
def proto(self) -> int:
return self._proto
- def _check_closed(self):
+ def _check_closed(self) -> None:
if self._closed:
_fake_err(errno.EBADF)
- def close(self):
+ def close(self) -> None:
# breakpoint()
if self._closed:
return
@@ -232,8 +250,10 @@ def close(self):
del self._fake_net._bound[self._binding]
self._packet_receiver.close()
- async def _resolve_address_nocp(self, address, *, local):
- return await trio._socket._resolve_address_nocp(
+ async def _resolve_address_nocp(
+ self, address: object, *, local: bool
+ ) -> tuple[str, int]:
+ return await trio._socket._resolve_address_nocp( # type: ignore[no-any-return]
self.type,
self.family,
self.proto,
@@ -253,7 +273,7 @@ def _deliver_packet(self, packet: UDPPacket) -> None:
# Actual IO operation implementations
################################################################
- async def bind(self, addr):
+ async def bind(self, addr: object) -> None:
self._check_closed()
if self._binding is not None:
_fake_err(errno.EINVAL)
@@ -272,10 +292,10 @@ async def bind(self, addr):
self._fake_net._bind(binding, self)
self._binding = binding
- async def connect(self, peer):
+ async def connect(self, peer: object) -> NoReturn:
raise NotImplementedError("FakeNet does not (yet) support connected sockets")
- async def sendmsg(self, *args):
+ async def sendmsg(self, *args: Any) -> int:
self._check_closed()
ancdata = []
flags = 0
@@ -310,6 +330,7 @@ async def sendmsg(self, *args):
payload = b"".join(buffers)
+ assert self._binding is not None
packet = UDPPacket(
source=self._binding.local,
destination=destination,
@@ -320,7 +341,12 @@ async def sendmsg(self, *args):
return len(payload)
- async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
+ async def recvmsg_into(
+ self,
+ buffers: Iterable[Buffer],
+ ancbufsize: int = 0,
+ flags: int = 0,
+ ) -> tuple[int, list[tuple[int, int, bytes]], int, Any]:
if ancbufsize != 0:
raise NotImplementedError("FakeNet doesn't support ancillary data")
if flags != 0:
@@ -328,7 +354,7 @@ async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
self._check_closed()
- ancdata = []
+ ancdata: list[tuple[int, int, bytes]] = []
msg_flags = 0
packet = await self._packet_receiver.receive()
@@ -342,7 +368,7 @@ async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
# Simple state query stuff
################################################################
- def getsockname(self):
+ def getsockname(self) -> tuple[str, int] | tuple[str, int, int, int]:
self._check_closed()
if self._binding is not None:
return self._binding.local.as_python_sockaddr()
@@ -352,31 +378,56 @@ def getsockname(self):
assert self.family == trio.socket.AF_INET6
return ("::", 0)
- def getpeername(self):
+ def getpeername(self) -> None:
self._check_closed()
if self._binding is not None:
if self._binding.remote is not None:
return self._binding.remote.as_python_sockaddr()
_fake_err(errno.ENOTCONN)
- def getsockopt(self, level, item):
+ @overload
+ def getsockopt(self, /, level: int, optname: int) -> int:
+ ...
+
+ @overload
+ def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes:
+ ...
+
+ def getsockopt(
+ self, /, level: int, optname: int, buflen: int | None = None
+ ) -> int | bytes:
self._check_closed()
- raise OSError(f"FakeNet doesn't implement getsockopt({level}, {item})")
+ raise OSError(f"FakeNet doesn't implement getsockopt({level}, {optname})")
+
+ @overload
+ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None:
+ ...
- def setsockopt(self, level, item, value):
+ @overload
+ def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None:
+ ...
+
+ def setsockopt(
+ self,
+ /,
+ level: int,
+ optname: int,
+ value: int | Buffer | None,
+ optlen: int | None = None,
+ ) -> None:
self._check_closed()
- if (level, item) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY):
+ if (level, optname) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY):
if not value:
raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True")
- raise OSError(f"FakeNet doesn't implement setsockopt({level}, {item}, ...)")
+ raise OSError(f"FakeNet doesn't implement setsockopt({level}, {optname}, ...)")
################################################################
# Various boilerplate and trivial stubs
################################################################
- def __enter__(self):
+ def __enter__(self) -> Self:
return self
def __exit__(
@@ -387,10 +438,10 @@ def __exit__(
) -> None:
self.close()
- async def send(self, data, flags=0):
+ async def send(self, data: Buffer, flags: int = 0) -> int:
return await self.sendto(data, flags, None)
- async def sendto(self, *args):
+ async def sendto(self, *args: Any) -> int:
if len(args) == 2:
data, address = args
flags = 0
@@ -400,19 +451,21 @@ async def sendto(self, *args):
raise TypeError("wrong number of arguments")
return await self.sendmsg([data], [], flags, address)
- async def recv(self, bufsize, flags=0):
+ async def recv(self, bufsize: int, flags: int = 0) -> bytes:
data, address = await self.recvfrom(bufsize, flags)
return data
- async def recv_into(self, buf, nbytes=0, flags=0):
+ async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int:
got_bytes, address = await self.recvfrom_into(buf, nbytes, flags)
return got_bytes
- async def recvfrom(self, bufsize, flags=0):
+ async def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, Any]:
data, ancdata, msg_flags, address = await self.recvmsg(bufsize, flags)
return data, address
- async def recvfrom_into(self, buf, nbytes=0, flags=0):
+ async def recvfrom_into(
+ self, buf: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> tuple[int, Any]:
if nbytes != 0 and nbytes != len(buf):
raise NotImplementedError("partial recvfrom_into")
got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into(
@@ -420,25 +473,27 @@ async def recvfrom_into(self, buf, nbytes=0, flags=0):
)
return got_nbytes, address
- async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
+ async def recvmsg(
+ self, bufsize: int, ancbufsize: int = 0, flags: int = 0
+ ) -> tuple[bytes, list[tuple[int, int, bytes]], int, Any]:
buf = bytearray(bufsize)
got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into(
[buf], ancbufsize, flags
)
return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address)
- def fileno(self):
+ def fileno(self) -> int:
raise NotImplementedError("can't get fileno() for FakeNet sockets")
- def detach(self):
+ def detach(self) -> int:
raise NotImplementedError("can't detach() a FakeNet socket")
- def get_inheritable(self):
+ def get_inheritable(self) -> bool:
return False
- def set_inheritable(self, inheritable):
+ def set_inheritable(self, inheritable: bool) -> None:
if inheritable:
raise NotImplementedError("FakeNet can't make inheritable sockets")
- def share(self, process_id):
+ def share(self, process_id: int) -> bytes:
raise NotImplementedError("FakeNet can't share sockets")
From ef91fb7d700fec6deb595b1852ab6b00a28419ab Mon Sep 17 00:00:00 2001
From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com>
Date: Wed, 18 Oct 2023 18:19:48 -0500
Subject: [PATCH 16/35] Fix runtime instantation of objects that don't define
`__class_getitem__`
---
trio/_core/_tests/test_guest_mode.py | 3 ++-
trio/_core/_tests/test_thread_cache.py | 8 ++++----
trio/_tests/test_threads.py | 13 +++++++------
3 files changed, 13 insertions(+), 11 deletions(-)
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index 3358f63a46..d53048b067 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -20,6 +20,7 @@
Callable,
NoReturn,
TypeVar,
+ cast,
)
import pytest
@@ -458,7 +459,7 @@ async def trio_main() -> str:
print("trio_main!")
to_trio, from_aio = trio.open_memory_channel[int](float("inf"))
- from_trio = asyncio.Queue[int]()
+ from_trio = cast("asyncio.Queue[int]", asyncio.Queue)()
aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio))
diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py
index 2b8d913948..277b6d6bb5 100644
--- a/trio/_core/_tests/test_thread_cache.py
+++ b/trio/_core/_tests/test_thread_cache.py
@@ -4,7 +4,7 @@
import time
from contextlib import contextmanager
from queue import Queue
-from typing import Iterator, NoReturn
+from typing import Iterator, NoReturn, cast
import pytest
from outcome import Outcome
@@ -16,7 +16,7 @@
def test_thread_cache_basics() -> None:
- q = Queue[Outcome]()
+ q = cast("Queue[Outcome]", Queue)()
def fn() -> NoReturn:
raise RuntimeError("hi")
@@ -41,7 +41,7 @@ def __call__(self) -> int:
def __del__(self) -> None:
res[0] = True
- q = Queue[Outcome]()
+ q = cast("Queue[Outcome]", Queue)()
def deliver(outcome: Outcome) -> None:
q.put(outcome)
@@ -64,7 +64,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None:
# Make sure there are a few threads running, so if we weren't LIFO then we
# could grab the wrong one.
- q = Queue[Outcome]()
+ q = cast("Queue[Outcome]", Queue)()
COUNT = 5
for _ in range(COUNT):
start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result))
diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py
index 77a3bf874c..866448dac8 100644
--- a/trio/_tests/test_threads.py
+++ b/trio/_tests/test_threads.py
@@ -20,6 +20,7 @@
Type,
TypeVar,
Union,
+ cast,
)
import pytest
@@ -346,7 +347,7 @@ async def child(q: stdlib_queue.Queue[None], cancellable: bool) -> None:
record.append("exit")
record: list[str] = []
- q = stdlib_queue.Queue[None]()
+ q = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue)()
async with _core.open_nursery() as nursery:
nursery.start_soon(child, q, True)
# Give it a chance to get started. (This is important because
@@ -394,8 +395,8 @@ def test_run_in_worker_thread_abandoned(
) -> None:
monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)
- q1 = stdlib_queue.Queue[None]()
- q2 = stdlib_queue.Queue[threading.Thread]()
+ q1 = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue)()
+ q2 = cast("stdlib_queue.Queue[threading.Thread]", stdlib_queue.Queue)()
def thread_fn() -> None:
q1.get()
@@ -918,7 +919,7 @@ def get_tid_then_reenter() -> int:
async def test_from_thread_host_cancelled() -> None:
- queue = stdlib_queue.Queue[bool]()
+ queue = cast("stdlib_queue.Queue[bool]", stdlib_queue.Queue)()
def sync_check() -> None:
from_thread_run_sync(cancel_scope.cancel)
@@ -977,7 +978,7 @@ async def async_time_bomb() -> None:
async def test_from_thread_check_cancelled() -> None:
- q = stdlib_queue.Queue[str]()
+ q = cast("stdlib_queue.Queue[str]", stdlib_queue.Queue)()
async def child(cancellable: bool, scope: CancelScope) -> None:
with scope:
@@ -1057,7 +1058,7 @@ def f() -> None: # type: ignore[no-redef] # noqa: F811
async def test_from_thread_check_cancelled_raises_in_foreign_threads() -> None:
with pytest.raises(RuntimeError):
from_thread_check_cancelled()
- q = stdlib_queue.Queue[Outcome]()
+ q = cast("stdlib_queue.Queue[Outcome]", stdlib_queue.Queue)()
_core.start_thread_soon(from_thread_check_cancelled, lambda _: q.put(_))
with pytest.raises(RuntimeError):
q.get(timeout=1).unwrap()
From 97eba76551cc6fe3465f7b45dfb9d15b3a25337f Mon Sep 17 00:00:00 2001
From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com>
Date: Wed, 18 Oct 2023 18:25:36 -0500
Subject: [PATCH 17/35] Missed one of the `Queue`s
---
trio/_core/_tests/test_thread_cache.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py
index 277b6d6bb5..8f9be0d23f 100644
--- a/trio/_core/_tests/test_thread_cache.py
+++ b/trio/_core/_tests/test_thread_cache.py
@@ -96,7 +96,7 @@ def test_idle_threads_exit(monkeypatch: MonkeyPatch) -> None:
# CPU.)
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
- q = Queue[threading.Thread]()
+ q = cast("Queue[threading.Thread]", Queue)()
start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread()))
seen_thread = q.get()
# Since the idle timeout is 0, after sleeping for 1 second, the thread
From f3e5a36b4705430870e3532655181a5aa7e237d7 Mon Sep 17 00:00:00 2001
From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com>
Date: Wed, 18 Oct 2023 18:35:44 -0500
Subject: [PATCH 18/35] Fix not callable issue I just created with the last fix
---
trio/_core/_tests/test_guest_mode.py | 2 +-
trio/_core/_tests/test_thread_cache.py | 8 ++++----
trio/_tests/test_threads.py | 12 ++++++------
3 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index d53048b067..bd0ef9ce69 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -459,7 +459,7 @@ async def trio_main() -> str:
print("trio_main!")
to_trio, from_aio = trio.open_memory_channel[int](float("inf"))
- from_trio = cast("asyncio.Queue[int]", asyncio.Queue)()
+ from_trio = cast("asyncio.Queue[int]", asyncio.Queue())
aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio))
diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py
index 8f9be0d23f..e5b4902d8d 100644
--- a/trio/_core/_tests/test_thread_cache.py
+++ b/trio/_core/_tests/test_thread_cache.py
@@ -16,7 +16,7 @@
def test_thread_cache_basics() -> None:
- q = cast("Queue[Outcome]", Queue)()
+ q = cast("Queue[Outcome]", Queue())
def fn() -> NoReturn:
raise RuntimeError("hi")
@@ -41,7 +41,7 @@ def __call__(self) -> int:
def __del__(self) -> None:
res[0] = True
- q = cast("Queue[Outcome]", Queue)()
+ q = cast("Queue[Outcome]", Queue())
def deliver(outcome: Outcome) -> None:
q.put(outcome)
@@ -64,7 +64,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None:
# Make sure there are a few threads running, so if we weren't LIFO then we
# could grab the wrong one.
- q = cast("Queue[Outcome]", Queue)()
+ q = cast("Queue[Outcome]", Queue())
COUNT = 5
for _ in range(COUNT):
start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result))
@@ -96,7 +96,7 @@ def test_idle_threads_exit(monkeypatch: MonkeyPatch) -> None:
# CPU.)
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
- q = cast("Queue[threading.Thread]", Queue)()
+ q = cast("Queue[threading.Thread]", Queue())
start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread()))
seen_thread = q.get()
# Since the idle timeout is 0, after sleeping for 1 second, the thread
diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py
index 866448dac8..4fa1e4bedf 100644
--- a/trio/_tests/test_threads.py
+++ b/trio/_tests/test_threads.py
@@ -347,7 +347,7 @@ async def child(q: stdlib_queue.Queue[None], cancellable: bool) -> None:
record.append("exit")
record: list[str] = []
- q = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue)()
+ q = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue())
async with _core.open_nursery() as nursery:
nursery.start_soon(child, q, True)
# Give it a chance to get started. (This is important because
@@ -395,8 +395,8 @@ def test_run_in_worker_thread_abandoned(
) -> None:
monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)
- q1 = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue)()
- q2 = cast("stdlib_queue.Queue[threading.Thread]", stdlib_queue.Queue)()
+ q1 = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue())
+ q2 = cast("stdlib_queue.Queue[threading.Thread]", stdlib_queue.Queue())
def thread_fn() -> None:
q1.get()
@@ -919,7 +919,7 @@ def get_tid_then_reenter() -> int:
async def test_from_thread_host_cancelled() -> None:
- queue = cast("stdlib_queue.Queue[bool]", stdlib_queue.Queue)()
+ queue = cast("stdlib_queue.Queue[bool]", stdlib_queue.Queue())
def sync_check() -> None:
from_thread_run_sync(cancel_scope.cancel)
@@ -978,7 +978,7 @@ async def async_time_bomb() -> None:
async def test_from_thread_check_cancelled() -> None:
- q = cast("stdlib_queue.Queue[str]", stdlib_queue.Queue)()
+ q = cast("stdlib_queue.Queue[str]", stdlib_queue.Queue())
async def child(cancellable: bool, scope: CancelScope) -> None:
with scope:
@@ -1058,7 +1058,7 @@ def f() -> None: # type: ignore[no-redef] # noqa: F811
async def test_from_thread_check_cancelled_raises_in_foreign_threads() -> None:
with pytest.raises(RuntimeError):
from_thread_check_cancelled()
- q = cast("stdlib_queue.Queue[Outcome]", stdlib_queue.Queue)()
+ q = cast("stdlib_queue.Queue[Outcome]", stdlib_queue.Queue())
_core.start_thread_soon(from_thread_check_cancelled, lambda _: q.put(_))
with pytest.raises(RuntimeError):
q.get(timeout=1).unwrap()
From 3c508020004ad03b076a256b8c8f2f64fc3966b4 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 19 Oct 2023 16:16:52 +0200
Subject: [PATCH 19/35] fixes after suggestions from teamspen & coolcat <3
---
trio/_core/_local.py | 2 +-
trio/_core/_tests/test_guest_mode.py | 7 +++++--
trio/_ssl.py | 2 +-
trio/_tests/test_file_io.py | 16 +++++++++-------
trio/_tests/test_highlevel_generic.py | 3 ++-
trio/_tests/test_ssl.py | 6 ++++--
trio/_tests/test_subprocess.py | 4 +++-
trio/_tests/test_sync.py | 6 ++++--
trio/_tests/test_threads.py | 11 ++++-------
trio/testing/_fake_net.py | 3 +--
10 files changed, 34 insertions(+), 26 deletions(-)
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index 27252bc78d..f1cbb3e61f 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -17,7 +17,7 @@ class _NoValue:
@final
-@attr.s(eq=False, hash=False, slots=False)
+@attr.s(eq=False, hash=False, slots=True)
class RunVarToken(Generic[T], metaclass=NoPublicConstructor):
_var: RunVar[T] = attr.ib()
previous_value: T | type[_NoValue] = attr.ib(default=_NoValue)
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index bd0ef9ce69..be9d4cb9c2 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -26,15 +26,18 @@
import pytest
from outcome import Outcome
from pytest import MonkeyPatch, WarningsRecorder
-from typing_extensions import TypeAlias
import trio
import trio.testing
from trio._channel import MemorySendChannel
+from trio.abc import Instrument
from ..._util import signal_raise
from .tutil import buggy_pypy_asyncgens, gc_collect_harder, restore_unraisablehook
+if TYPE_CHECKING:
+ from typing_extensions import TypeAlias
+
T = TypeVar("T")
InHost: TypeAlias = Callable[[object], None]
@@ -346,7 +349,7 @@ async def get_woken_by_host_deadline(watb_cscope: trio.CancelScope) -> None:
# 'sit_in_wait_all_tasks_blocked', we want the test to
# actually end. So in after_io_wait we schedule a second host
# call to tear things down.
- class InstrumentHelper(trio._abc.Instrument):
+ class InstrumentHelper(Instrument):
def __init__(self) -> None:
self.primed = False
diff --git a/trio/_ssl.py b/trio/_ssl.py
index 77d3b80140..57f41c11d2 100644
--- a/trio/_ssl.py
+++ b/trio/_ssl.py
@@ -239,7 +239,7 @@ def done(self) -> bool:
_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"])
-# TODO: variance
+# invariant
T_Stream = TypeVar("T_Stream", bound=Stream)
diff --git a/trio/_tests/test_file_io.py b/trio/_tests/test_file_io.py
index d438a9fb10..5d617a0661 100644
--- a/trio/_tests/test_file_io.py
+++ b/trio/_tests/test_file_io.py
@@ -244,13 +244,15 @@ async def test_aclose_cancelled(path: pathlib.Path) -> None:
assert f.closed
-async def test_detach_rewraps_asynciobase() -> None:
- raw = io.BytesIO()
- buffered = io.BufferedReader(raw) # type: ignore[arg-type] # ????????????
+async def test_detach_rewraps_asynciobase(tmp_path: pathlib.Path) -> None:
+ tmp_file = tmp_path / "filename"
+ tmp_file.touch()
+ with open(tmp_file, mode="rb", buffering=0) as raw:
+ buffered = io.BufferedReader(raw)
- async_file = trio.wrap_file(buffered)
+ async_file = trio.wrap_file(buffered)
- detached = await async_file.detach()
+ detached = await async_file.detach()
- assert isinstance(detached, AsyncIOWrapper)
- assert detached.wrapped is raw
+ assert isinstance(detached, AsyncIOWrapper)
+ assert detached.wrapped is raw
diff --git a/trio/_tests/test_highlevel_generic.py b/trio/_tests/test_highlevel_generic.py
index 4b2008c08c..3e9fc212a8 100644
--- a/trio/_tests/test_highlevel_generic.py
+++ b/trio/_tests/test_highlevel_generic.py
@@ -27,8 +27,9 @@ async def aclose(self) -> None:
class RecordReceiveStream(ReceiveStream):
record: list[str | tuple[str, int | None]] = attr.ib(factory=list)
- async def receive_some(self, max_bytes: int | None = None) -> None: # type: ignore[override]
+ async def receive_some(self, max_bytes: int | None = None) -> bytes:
self.record.append(("receive_some", max_bytes))
+ return b""
async def aclose(self) -> None:
self.record.append("aclose")
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index 58e069f239..4a4ba95ecd 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -8,10 +8,9 @@
from contextlib import asynccontextmanager, contextmanager
from functools import partial
from ssl import SSLContext
-from typing import Any, AsyncIterator, Iterator, NoReturn
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, NoReturn
import pytest
-from typing_extensions import TypeAlias
from trio import StapledStream
from trio._core import MockClock
@@ -46,6 +45,9 @@
memory_stream_pair,
)
+if TYPE_CHECKING:
+ from typing_extensions import TypeAlias
+
# We have two different kinds of echo server fixtures we use for testing. The
# first is a real server written using the stdlib ssl module and blocking
# sockets. It runs in a thread and we talk to it over a real socketpair(), to
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index e4713b3950..8c32f1c49b 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -20,7 +20,6 @@
import pytest
from pytest import MonkeyPatch, WarningsRecorder
-from typing_extensions import TypeAlias
from .. import (
ClosedResourceError,
@@ -38,6 +37,9 @@
from ..lowlevel import open_process
from ..testing import MockClock, assert_no_checkpoints, wait_all_tasks_blocked
+if TYPE_CHECKING:
+ from typing_extensions import TypeAlias
+
if sys.platform == "win32":
SignalType: TypeAlias = None
else:
diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py
index 448ead15da..3f9ca88650 100644
--- a/trio/_tests/test_sync.py
+++ b/trio/_tests/test_sync.py
@@ -1,16 +1,18 @@
from __future__ import annotations
import weakref
-from typing import Callable, Union
+from typing import TYPE_CHECKING, Callable, Union
import pytest
-from typing_extensions import TypeAlias
from .. import _core
from .._sync import *
from .._timeouts import sleep_forever
from ..testing import assert_checkpoints, wait_all_tasks_blocked
+if TYPE_CHECKING:
+ from typing_extensions import TypeAlias
+
async def test_Event() -> None:
e = Event()
diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py
index 4fa1e4bedf..77650bf569 100644
--- a/trio/_tests/test_threads.py
+++ b/trio/_tests/test_threads.py
@@ -9,7 +9,6 @@
import weakref
from functools import partial
from typing import (
- TYPE_CHECKING,
AsyncGenerator,
Awaitable,
Callable,
@@ -462,12 +461,10 @@ async def test_run_in_worker_thread_limiter(
# Mutating them in-place is OK though (as long as you use proper
# locking etc.).
class state:
- if TYPE_CHECKING:
- ran: int
- high_water: int
- running: int
- parked: int
- pass
+ ran: int
+ high_water: int
+ running: int
+ parked: int
state.ran = 0
state.high_water = 0
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
index 2b1d2c8b34..fbe2670f46 100644
--- a/trio/testing/_fake_net.py
+++ b/trio/testing/_fake_net.py
@@ -24,7 +24,6 @@
)
import attr
-from typing_extensions import Buffer, Self
import trio
from trio._util import NoPublicConstructor, final
@@ -33,7 +32,7 @@
from socket import AddressFamily, SocketKind
from types import TracebackType
- from typing_extensions import TypeAlias
+ from typing_extensions import Buffer, Self, TypeAlias
IPAddress: TypeAlias = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
From 7f0121ba5dd759b19f6332fdb4ea7b20465622b7 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 20 Oct 2023 16:35:21 +0200
Subject: [PATCH 20/35] fix a few more type errors
---
trio/_highlevel_serve_listeners.py | 6 +++---
trio/_highlevel_ssl_helpers.py | 6 ++----
trio/_subprocess_platform/waitid.py | 5 +++++
trio/testing/_fake_net.py | 28 +++++++++++++++++++++-------
4 files changed, 31 insertions(+), 14 deletions(-)
diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py
index d5c7a3bdad..d949cb8f87 100644
--- a/trio/_highlevel_serve_listeners.py
+++ b/trio/_highlevel_serve_listeners.py
@@ -66,9 +66,7 @@ async def _serve_one_listener(
# https://github.com/python/typing/issues/548
-# It does never return (since _serve_one_listener never completes), but type checkers can't
-# understand nurseries.
-async def serve_listeners( # type: ignore[misc]
+async def serve_listeners(
handler: Handler[StreamT],
listeners: list[ListenerT],
*,
@@ -143,3 +141,5 @@ async def serve_listeners( # type: ignore[misc]
# but we wait until the end to call started() just in case we get an
# error or whatever.
task_status.started(listeners)
+
+ raise AssertionError("_serve_one_listener should never complete")
diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py
index c03919d6c0..0006d7a41f 100644
--- a/trio/_highlevel_ssl_helpers.py
+++ b/trio/_highlevel_ssl_helpers.py
@@ -8,10 +8,8 @@
from ._highlevel_open_tcp_stream import DEFAULT_DELAY
from ._highlevel_socket import SocketStream
-from .abc import Stream
T = TypeVar("T")
-T_Stream = TypeVar("T_Stream", bound=Stream)
# It might have been nice to take a ssl_protocols= argument here to set up
@@ -100,7 +98,7 @@ async def open_ssl_over_tcp_listeners(
async def serve_ssl_over_tcp(
- handler: Callable[[trio.SSLStream[T_Stream]], Awaitable[object]],
+ handler: Callable[[trio.SSLStream[SocketStream]], Awaitable[object]],
port: int,
ssl_context: ssl.SSLContext,
*,
@@ -109,7 +107,7 @@ async def serve_ssl_over_tcp(
backlog: int | float | None = None,
handler_nursery: trio.Nursery | None = None,
task_status: trio.TaskStatus[
- list[trio.SSLListener[T_Stream]]
+ list[trio.SSLListener[SocketStream]]
] = trio.TASK_STATUS_IGNORED,
) -> NoReturn:
"""Listen for incoming TCP connections, and for each one start a task
diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py
index 7d941747e3..60520901b7 100644
--- a/trio/_subprocess_platform/waitid.py
+++ b/trio/_subprocess_platform/waitid.py
@@ -10,6 +10,11 @@
assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING
+if TYPE_CHECKING:
+
+ def sync_wait_reapable(pid: int) -> None:
+ ...
+
try:
from os import waitid
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
index fbe2670f46..5f6b0e9892 100644
--- a/trio/testing/_fake_net.py
+++ b/trio/testing/_fake_net.py
@@ -67,10 +67,10 @@ def _fake_err(code: int) -> NoReturn:
raise OSError(code, os.strerror(code))
-def _scatter(data: bytes, buffers: Iterable[bytes]) -> int:
+def _scatter(data: bytes, buffers: Iterable[Buffer]) -> int:
written = 0
for buf in buffers:
- next_piece = data[written : written + len(buf)]
+ next_piece = data[written : written + memoryview(buf).nbytes]
with memoryview(buf) as mbuf:
mbuf[: len(next_piece)] = next_piece
written += len(next_piece)
@@ -107,6 +107,7 @@ def from_python_sockaddr(
@attr.frozen
class UDPBinding:
local: UDPEndpoint
+ # remote: UDPEndpoint # ??
@attr.frozen
@@ -299,6 +300,10 @@ async def sendmsg(self, *args: Any) -> int:
ancdata = []
flags = 0
address = None
+
+ # This does *not* match up with socket.socket.sendmsg (!!!)
+ # https://docs.python.org/3/library/socket.html#socket.socket.sendmsg
+ # they always have (buffers, ancdata, flags, address)
if len(args) == 1:
(buffers,) = args
elif len(args) == 2:
@@ -377,10 +382,17 @@ def getsockname(self) -> tuple[str, int] | tuple[str, int, int, int]:
assert self.family == trio.socket.AF_INET6
return ("::", 0)
- def getpeername(self) -> None:
+ # TODO: This method is not tested, and seems to make incorrect assumptions. It should maybe raise NotImplementedError.
+ def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]:
self._check_closed()
if self._binding is not None:
+ assert hasattr(
+ self._binding, "remote"
+ ), "This method seems to assume that self._binding has a remote UDPEndpoint"
if self._binding.remote is not None:
+ assert isinstance(
+ self._binding.remote, UDPEndpoint
+ ), "Self._binding.remote should be a UDPEndpoint"
return self._binding.remote.as_python_sockaddr()
_fake_err(errno.ENOTCONN)
@@ -416,9 +428,11 @@ def setsockopt(
) -> None:
self._check_closed()
- if (level, optname) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY):
- if not value:
- raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True")
+ if (level, optname) == (
+ trio.socket.IPPROTO_IPV6,
+ trio.socket.IPV6_V6ONLY,
+ ) and not value:
+ raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True")
raise OSError(f"FakeNet doesn't implement setsockopt({level}, {optname}, ...)")
@@ -465,7 +479,7 @@ async def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, Any]:
async def recvfrom_into(
self, buf: Buffer, nbytes: int = 0, flags: int = 0
) -> tuple[int, Any]:
- if nbytes != 0 and nbytes != len(buf):
+ if nbytes != 0 and nbytes != memoryview(buf).nbytes:
raise NotImplementedError("partial recvfrom_into")
got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into(
[buf], 0, flags
From af9b7752a8e686679a4f18a8affe538389aa7160 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 20 Oct 2023 17:15:50 +0200
Subject: [PATCH 21/35] enable disallow_any_unimported and
disallow_untyped_calls, fix generic Outcome's, clean up ugly casts introduced
by pyannotate
---
pyproject.toml | 6 ++----
trio/_core/_multierror.py | 5 ++---
trio/_core/_tests/test_guest_mode.py | 13 +++++++------
trio/_core/_tests/test_thread_cache.py | 14 +++++++-------
trio/_file_io.py | 2 +-
trio/_highlevel_ssl_helpers.py | 4 ++--
trio/_path.py | 2 +-
trio/_socket.py | 2 +-
trio/_ssl.py | 2 +-
trio/_tests/test_highlevel_ssl_helpers.py | 15 ++++++++++-----
trio/_tests/test_threads.py | 13 ++++++-------
trio/_tests/test_timeouts.py | 6 ++----
trio/_tests/test_util.py | 6 +++---
13 files changed, 45 insertions(+), 45 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 75849ec9bd..dfeab31760 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,16 +70,14 @@ warn_return_any = true
# Avoid subtle backsliding
disallow_any_decorated = true
disallow_any_generics = true
+disallow_any_unimported = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
+disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
check_untyped_defs = true
-# Enable once other problems are dealt with
-disallow_untyped_calls = false
-disallow_any_unimported = false # Enable once Outcome has stubs.
-
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
faulthandler_timeout = 60
diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py
index 05eaad33e3..79c1cd0c7c 100644
--- a/trio/_core/_multierror.py
+++ b/trio/_core/_multierror.py
@@ -424,9 +424,8 @@ def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackT
else:
# http://doc.pypy.org/en/latest/objspace-proxies.html
def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType:
- # Mypy refuses to believe that ProxyOperation can be imported properly
- # TODO: will need no-any-unimported if/when that's toggled on
- def controller(operation: tputil.ProxyOperation) -> Any | None:
+ # tputil.ProxyOperation is PyPy-only, but we run mypy on CPython
+ def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[no-any-unimported]
# Rationale for pragma: I looked fairly carefully and tried a few
# things, and AFAICT it's not actually possible to get any
# 'opname' that isn't __getattr__ or __getattribute__. So there's
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index be9d4cb9c2..dd8847b448 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -20,7 +20,6 @@
Callable,
NoReturn,
TypeVar,
- cast,
)
import pytest
@@ -54,7 +53,7 @@ def trivial_guest_run(
in_host_after_start: Callable[[], None] | None = None,
**start_guest_run_kwargs: Any,
) -> T:
- todo: queue.Queue[tuple[str, Outcome]] = queue.Queue()
+ todo: queue.Queue[tuple[str, Outcome[T] | Callable[..., object]]] = queue.Queue()
host_thread = threading.current_thread()
@@ -76,7 +75,7 @@ def run_sync_soon_not_threadsafe(fn: Callable[[], object]) -> None:
todo.put(("run", crash))
todo.put(("run", fn))
- def done_callback(outcome: Outcome) -> None:
+ def done_callback(outcome: Outcome[T]) -> None:
nonlocal todo
todo.put(("unwrap", outcome))
@@ -95,9 +94,11 @@ def done_callback(outcome: Outcome) -> None:
while True:
op, obj = todo.get()
if op == "run":
+ assert not isinstance(obj, Outcome)
obj()
elif op == "unwrap":
- return obj.unwrap() # type: ignore[no-any-return]
+ assert isinstance(obj, Outcome)
+ return obj.unwrap()
else: # pragma: no cover
assert False
finally:
@@ -435,7 +436,7 @@ def aiotrio_run(
async def aio_main() -> T:
trio_done_fut = loop.create_future()
- def trio_done_callback(main_outcome: Outcome) -> None:
+ def trio_done_callback(main_outcome: Outcome[object]) -> None:
print(f"trio_fn finished: {main_outcome!r}")
trio_done_fut.set_result(main_outcome)
@@ -462,7 +463,7 @@ async def trio_main() -> str:
print("trio_main!")
to_trio, from_aio = trio.open_memory_channel[int](float("inf"))
- from_trio = cast("asyncio.Queue[int]", asyncio.Queue())
+ from_trio: asyncio.Queue[int] = asyncio.Queue()
aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio))
diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py
index e5b4902d8d..77fdf46664 100644
--- a/trio/_core/_tests/test_thread_cache.py
+++ b/trio/_core/_tests/test_thread_cache.py
@@ -4,7 +4,7 @@
import time
from contextlib import contextmanager
from queue import Queue
-from typing import Iterator, NoReturn, cast
+from typing import Iterator, NoReturn
import pytest
from outcome import Outcome
@@ -16,12 +16,12 @@
def test_thread_cache_basics() -> None:
- q = cast("Queue[Outcome]", Queue())
+ q: Queue[Outcome[object]] = Queue()
def fn() -> NoReturn:
raise RuntimeError("hi")
- def deliver(outcome: Outcome) -> None:
+ def deliver(outcome: Outcome[object]) -> None:
q.put(outcome)
start_thread_soon(fn, deliver)
@@ -41,9 +41,9 @@ def __call__(self) -> int:
def __del__(self) -> None:
res[0] = True
- q = cast("Queue[Outcome]", Queue())
+ q: Queue[Outcome[int]] = Queue()
- def deliver(outcome: Outcome) -> None:
+ def deliver(outcome: Outcome[int]) -> None:
q.put(outcome)
start_thread_soon(del_me(), deliver)
@@ -64,7 +64,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None:
# Make sure there are a few threads running, so if we weren't LIFO then we
# could grab the wrong one.
- q = cast("Queue[Outcome]", Queue())
+ q: Queue[Outcome[object]] = Queue()
COUNT = 5
for _ in range(COUNT):
start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result))
@@ -96,7 +96,7 @@ def test_idle_threads_exit(monkeypatch: MonkeyPatch) -> None:
# CPU.)
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
- q = cast("Queue[threading.Thread]", Queue())
+ q: Queue[threading.Thread] = Queue()
start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread()))
seen_thread = q.get()
# Since the idle timeout is 0, after sleeping for 1 second, the thread
diff --git a/trio/_file_io.py b/trio/_file_io.py
index 6f30f89dd8..4a5650b453 100644
--- a/trio/_file_io.py
+++ b/trio/_file_io.py
@@ -430,7 +430,7 @@ async def open_file(
@overload
-async def open_file( # type: ignore[misc] # Any usage matches builtins.open().
+async def open_file(
file: _OpenFile,
mode: str,
buffering: int = -1,
diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py
index 0006d7a41f..321119216f 100644
--- a/trio/_highlevel_ssl_helpers.py
+++ b/trio/_highlevel_ssl_helpers.py
@@ -22,7 +22,7 @@
# So... let's punt on that for now. Hopefully we'll be getting a new Python
# TLS API soon and can revisit this then.
async def open_ssl_over_tcp_stream(
- host: str | bytes,
+ host: str,
port: int,
*,
https_compatible: bool = False,
@@ -40,7 +40,7 @@ async def open_ssl_over_tcp_stream(
data.
Args:
- host (bytes or str): The host to connect to. We require the server
+ host (str): The host to connect to. We require the server
to have a TLS certificate valid for this hostname.
port (int): The port to connect to.
https_compatible (bool): Set this to True if you're connecting to a web
diff --git a/trio/_path.py b/trio/_path.py
index 219f706825..fb01420a75 100644
--- a/trio/_path.py
+++ b/trio/_path.py
@@ -311,7 +311,7 @@ async def open(
...
@overload
- async def open( # type: ignore[misc] # Any usage matches builtins.open().
+ async def open(
self,
mode: str,
buffering: int = -1,
diff --git a/trio/_socket.py b/trio/_socket.py
index 7f6d9b3581..bce0fcf9b3 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -1188,7 +1188,7 @@ async def sendto(self, *args: Any) -> int:
# We don't care about invalid types, sendto() will do the checking.
return await self._nonblocking_helper(
_core.wait_writable,
- _stdlib_socket.socket.sendto, # type: ignore[arg-type]
+ _stdlib_socket.socket.sendto,
*args_list,
)
diff --git a/trio/_ssl.py b/trio/_ssl.py
index 78123c9e68..9a9b608f9e 100644
--- a/trio/_ssl.py
+++ b/trio/_ssl.py
@@ -345,7 +345,7 @@ def __init__(
transport_stream: T_Stream,
ssl_context: _stdlib_ssl.SSLContext,
*,
- server_hostname: str | bytes | None = None,
+ server_hostname: str | None = None,
server_side: bool = False,
https_compatible: bool = False,
) -> None:
diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py
index c1b0febbd5..8e90adb3d2 100644
--- a/trio/_tests/test_highlevel_ssl_helpers.py
+++ b/trio/_tests/test_highlevel_ssl_helpers.py
@@ -72,15 +72,20 @@ async def test_open_ssl_over_tcp_stream_and_everything_else(
client_ctx: SSLContext, # noqa: F811 # linters doesn't understand fixture
) -> None:
async with trio.open_nursery() as nursery:
- # TODO: the types are *very* funky here, this seems like an error in some signature
- # unless this is doing stuff we don't want/expect end users to do
- res: list[SSLListener[SocketListener]] = await nursery.start(
- partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1")
+ # TODO: this function wraps an SSLListener around a SocketListener, this is illegal
+ # according to current type hints, and probably for good reason. But there should
+ # maybe be a different wrapper class/function that could be used instead?
+ res: list[SSLListener[SocketListener]] = ( # type: ignore[type-var]
+ await nursery.start(
+ partial(
+ serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1"
+ )
+ )
)
(listener,) = res
async with listener:
# listener.transport_listener is of type Listener[Stream]
- tp_listener: SocketListener = listener.transport_listener
+ tp_listener: SocketListener = listener.transport_listener # type: ignore[assignment]
sockaddr = tp_listener.socket.getsockname()
hostname_resolver = FakeHostnameResolver(sockaddr)
diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py
index 77650bf569..327c35a4d9 100644
--- a/trio/_tests/test_threads.py
+++ b/trio/_tests/test_threads.py
@@ -19,7 +19,6 @@
Type,
TypeVar,
Union,
- cast,
)
import pytest
@@ -346,7 +345,7 @@ async def child(q: stdlib_queue.Queue[None], cancellable: bool) -> None:
record.append("exit")
record: list[str] = []
- q = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue())
+ q: stdlib_queue.Queue[None] = stdlib_queue.Queue()
async with _core.open_nursery() as nursery:
nursery.start_soon(child, q, True)
# Give it a chance to get started. (This is important because
@@ -394,8 +393,8 @@ def test_run_in_worker_thread_abandoned(
) -> None:
monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)
- q1 = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue())
- q2 = cast("stdlib_queue.Queue[threading.Thread]", stdlib_queue.Queue())
+ q1: stdlib_queue.Queue[None] = stdlib_queue.Queue()
+ q2: stdlib_queue.Queue[threading.Thread] = stdlib_queue.Queue()
def thread_fn() -> None:
q1.get()
@@ -916,7 +915,7 @@ def get_tid_then_reenter() -> int:
async def test_from_thread_host_cancelled() -> None:
- queue = cast("stdlib_queue.Queue[bool]", stdlib_queue.Queue())
+ queue: stdlib_queue.Queue[bool] = stdlib_queue.Queue()
def sync_check() -> None:
from_thread_run_sync(cancel_scope.cancel)
@@ -975,7 +974,7 @@ async def async_time_bomb() -> None:
async def test_from_thread_check_cancelled() -> None:
- q = cast("stdlib_queue.Queue[str]", stdlib_queue.Queue())
+ q: stdlib_queue.Queue[str] = stdlib_queue.Queue()
async def child(cancellable: bool, scope: CancelScope) -> None:
with scope:
@@ -1055,7 +1054,7 @@ def f() -> None: # type: ignore[no-redef] # noqa: F811
async def test_from_thread_check_cancelled_raises_in_foreign_threads() -> None:
with pytest.raises(RuntimeError):
from_thread_check_cancelled()
- q = cast("stdlib_queue.Queue[Outcome]", stdlib_queue.Queue())
+ q: stdlib_queue.Queue[Outcome[object]] = stdlib_queue.Queue()
_core.start_thread_soon(from_thread_check_cancelled, lambda _: q.put(_))
with pytest.raises(RuntimeError):
q.get(timeout=1).unwrap()
diff --git a/trio/_tests/test_timeouts.py b/trio/_tests/test_timeouts.py
index 918e763faa..43fadfdaca 100644
--- a/trio/_tests/test_timeouts.py
+++ b/trio/_tests/test_timeouts.py
@@ -12,9 +12,7 @@
T = TypeVar("T")
-async def check_takes_about(
- f: Callable[[], Awaitable[T]], expected_dur: float
-) -> Awaitable[T]:
+async def check_takes_about(f: Callable[[], Awaitable[T]], expected_dur: float) -> T:
start = time.perf_counter()
result = await outcome.acapture(f)
dur = time.perf_counter() - start
@@ -41,7 +39,7 @@ async def check_takes_about(
assert (1 - 1e-8) <= (dur / expected_dur) < 1.5
# outcome is not typed
- return result.unwrap() # type: ignore[no-any-return]
+ return result.unwrap()
# How long to (attempt to) sleep for when testing. Smaller numbers make the
diff --git a/trio/_tests/test_util.py b/trio/_tests/test_util.py
index 01a7503e9d..40c2fd11bb 100644
--- a/trio/_tests/test_util.py
+++ b/trio/_tests/test_util.py
@@ -270,7 +270,7 @@ def test_fixup_module_metadata() -> None:
assert mod.SomeClass.method.__module__ == "trio.somemodule" # type: ignore[attr-defined]
assert mod.SomeClass.method.__qualname__ == "SomeClass.method" # type: ignore[attr-defined]
# Make coverage happy.
- non_trio_module.some_func()
- mod.some_func()
- mod._private()
+ non_trio_module.some_func() # type: ignore[no-untyped-call]
+ mod.some_func() # type: ignore[no-untyped-call]
+ mod._private() # type: ignore[no-untyped-call]
mod.SomeClass().method()
From 077092b694020570e54ffd84f3bbe06b0c615342 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 20 Oct 2023 17:49:12 +0200
Subject: [PATCH 22/35] readd type: ignore's incorrectly removed, fix
type-checking on non-linux
---
trio/_file_io.py | 2 +-
trio/_path.py | 2 +-
trio/_socket.py | 2 +-
trio/_subprocess_platform/waitid.py | 4 ++--
trio/_tests/test_exports.py | 3 ++-
trio/_tests/test_highlevel_open_unix_stream.py | 4 ++++
trio/_tests/test_subprocess.py | 12 +++++++++---
7 files changed, 20 insertions(+), 9 deletions(-)
diff --git a/trio/_file_io.py b/trio/_file_io.py
index 4a5650b453..6f30f89dd8 100644
--- a/trio/_file_io.py
+++ b/trio/_file_io.py
@@ -430,7 +430,7 @@ async def open_file(
@overload
-async def open_file(
+async def open_file( # type: ignore[misc] # Any usage matches builtins.open().
file: _OpenFile,
mode: str,
buffering: int = -1,
diff --git a/trio/_path.py b/trio/_path.py
index fb01420a75..219f706825 100644
--- a/trio/_path.py
+++ b/trio/_path.py
@@ -311,7 +311,7 @@ async def open(
...
@overload
- async def open(
+ async def open( # type: ignore[misc] # Any usage matches builtins.open().
self,
mode: str,
buffering: int = -1,
diff --git a/trio/_socket.py b/trio/_socket.py
index bce0fcf9b3..7f6d9b3581 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -1188,7 +1188,7 @@ async def sendto(self, *args: Any) -> int:
# We don't care about invalid types, sendto() will do the checking.
return await self._nonblocking_helper(
_core.wait_writable,
- _stdlib_socket.socket.sendto,
+ _stdlib_socket.socket.sendto, # type: ignore[arg-type]
*args_list,
)
diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py
index 60520901b7..03a464d6a9 100644
--- a/trio/_subprocess_platform/waitid.py
+++ b/trio/_subprocess_platform/waitid.py
@@ -8,14 +8,14 @@
from .._sync import CapacityLimiter, Event
from .._threads import to_thread_run_sync
-assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING
-
if TYPE_CHECKING:
def sync_wait_reapable(pid: int) -> None:
...
+assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING
+
try:
from os import waitid
diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py
index 47671b74a3..74a87f7115 100644
--- a/trio/_tests/test_exports.py
+++ b/trio/_tests/test_exports.py
@@ -157,8 +157,9 @@ def no_underscores(symbols: Iterable[str]) -> set[str]:
skip_if_optional_else_raise(error)
linter = PyLinter()
+ assert module.__file__ is not None
ast = linter.get_ast(module.__file__, modname)
- static_names = no_underscores(ast)
+ static_names = no_underscores(ast) # type: ignore[arg-type]
elif tool == "jedi":
try:
import jedi
diff --git a/trio/_tests/test_highlevel_open_unix_stream.py b/trio/_tests/test_highlevel_open_unix_stream.py
index 045820fccf..0ff11209a7 100644
--- a/trio/_tests/test_highlevel_open_unix_stream.py
+++ b/trio/_tests/test_highlevel_open_unix_stream.py
@@ -1,12 +1,16 @@
import os
import socket
+import sys
import tempfile
+from typing import TYPE_CHECKING
import pytest
from trio import Path, open_unix_socket
from trio._highlevel_open_unix_stream import close_on_error
+assert not TYPE_CHECKING or sys.platform != "win32"
+
if not hasattr(socket, "AF_UNIX"):
pytestmark = pytest.mark.skip("Needs unix socket support")
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index 8c32f1c49b..51389da518 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -79,7 +79,7 @@ def SLEEP(seconds: int) -> list[str]:
def got_signal(proc: Process, sig: SignalType) -> bool:
- if posix:
+ if (not TYPE_CHECKING and posix) or sys.platform != "win32":
return proc.returncode == -sig
else:
return proc.returncode != 0
@@ -444,7 +444,8 @@ async def test_stderr_stdout(background_process: BackgroundProcessType) -> None:
async def test_errors() -> None:
with pytest.raises(TypeError) as excinfo:
- await open_process(["ls"], encoding="utf-8") # type: ignore[call-overload]
+ # call-overload on unix, call-arg on windows
+ await open_process(["ls"], encoding="utf-8") # type: ignore
assert "unbuffered byte streams" in str(excinfo.value)
assert "the 'encoding' option is not supported" in str(excinfo.value)
@@ -480,13 +481,15 @@ async def test_one_signal(
# tries to handle SIGINT during startup. SIGUSR1's default disposition is
# to terminate the target process, and Python doesn't try to do anything
# clever to handle it.
- if posix:
+ if (not TYPE_CHECKING and posix) or sys.platform != "win32":
await test_one_signal(lambda proc: proc.send_signal(SIGUSR1), SIGUSR1)
@pytest.mark.skipif(not posix, reason="POSIX specific")
@background_process_param
async def test_wait_reapable_fails(background_process: BackgroundProcessType) -> None:
+ if TYPE_CHECKING and sys.platform == "win32":
+ return
old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
try:
# With SIGCHLD disabled, the wait() syscall will wait for the
@@ -510,6 +513,9 @@ def test_waitid_eintr() -> None:
# ourselves) but the test works on all waitid platforms.
from .._subprocess_platform import wait_child_exiting
+ if TYPE_CHECKING and sys.platform == "win32":
+ return
+
if not wait_child_exiting.__module__.endswith("waitid"):
pytest.skip("waitid only")
from .._subprocess_platform.waitid import sync_wait_reapable
From fcdbdbc410189f6f00e729bef45df0b397b5b484 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 23 Oct 2023 16:27:20 +0200
Subject: [PATCH 23/35] fix ruff/mypy issues
---
trio/_tests/test_file_io.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/trio/_tests/test_file_io.py b/trio/_tests/test_file_io.py
index efcfedeb09..85b8324c56 100644
--- a/trio/_tests/test_file_io.py
+++ b/trio/_tests/test_file_io.py
@@ -83,7 +83,8 @@ def unsupported_attr(self) -> None: # pragma: no cover
assert hasattr(async_file.wrapped, "unsupported_attr")
with pytest.raises(AttributeError):
- async_file.unsupported_attr # noqa: B018 # "useless expression"
+ # B018 "useless expression"
+ async_file.unsupported_attr # type: ignore[attr-defined] # noqa: B018
def test_type_stubs_match_lists() -> None:
@@ -247,7 +248,8 @@ async def test_aclose_cancelled(path: pathlib.Path) -> None:
async def test_detach_rewraps_asynciobase(tmp_path: pathlib.Path) -> None:
tmp_file = tmp_path / "filename"
tmp_file.touch()
- with open(tmp_file, mode="rb", buffering=0) as raw:
+ # flake8-async does not like opening files in async mode
+ with open(tmp_file, mode="rb", buffering=0) as raw: # noqa: ASYNC101
buffered = io.BufferedReader(raw)
async_file = trio.wrap_file(buffered)
From 16631e99d10b552a7541c1f687f37d12534867bc Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 23 Oct 2023 16:35:40 +0200
Subject: [PATCH 24/35] fix B904 error
---
trio/_threads.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/trio/_threads.py b/trio/_threads.py
index 3a8d83a55a..d44e4b68f8 100644
--- a/trio/_threads.py
+++ b/trio/_threads.py
@@ -409,10 +409,10 @@ def from_thread_check_cancelled() -> None:
"""
try:
raise_cancel = PARENT_TASK_DATA.cancel_register[0]
- except AttributeError:
+ except AttributeError as exc:
raise RuntimeError(
"this thread wasn't created by Trio, can't check for cancellation"
- )
+ ) from exc
if raise_cancel is not None:
raise_cancel()
From 208b1a3f14f91b5bc4e5ce8f991c61725042cc50 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 25 Oct 2023 13:51:26 +0200
Subject: [PATCH 25/35] fix after merge + coolcat review
---
trio/_tests/test_subprocess.py | 3 ++-
trio/_tests/test_threads.py | 2 +-
trio/_tests/test_timeouts.py | 1 -
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index 51389da518..7a5fafdddb 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -9,6 +9,7 @@
from functools import partial
from pathlib import Path as SyncPath
from signal import Signals
+from types import FrameType
from typing import (
TYPE_CHECKING,
Any,
@@ -523,7 +524,7 @@ def test_waitid_eintr() -> None:
got_alarm = False
sleeper = subprocess.Popen(["sleep", "3600"])
- def on_alarm(sig: object, frame: object) -> None:
+ def on_alarm(sig: int, frame: FrameType | None) -> None:
nonlocal got_alarm
got_alarm = True
sleeper.kill()
diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py
index 90755501bc..637a035d63 100644
--- a/trio/_tests/test_threads.py
+++ b/trio/_tests/test_threads.py
@@ -1062,7 +1062,7 @@ async def test_from_thread_check_cancelled_raises_in_foreign_threads() -> None:
@slow
-async def test_reentry_doesnt_deadlock():
+async def test_reentry_doesnt_deadlock() -> None:
# Regression test for issue noticed in GH-2827
# The failure mode is to hang the whole test suite, unfortunately.
# XXX consider running this in a subprocess with a timeout, if it comes up again!
diff --git a/trio/_tests/test_timeouts.py b/trio/_tests/test_timeouts.py
index 43fadfdaca..c6def0bf9e 100644
--- a/trio/_tests/test_timeouts.py
+++ b/trio/_tests/test_timeouts.py
@@ -38,7 +38,6 @@ async def check_takes_about(f: Callable[[], Awaitable[T]], expected_dur: float)
# started as a 1 ULP error at a different dynamic range.)
assert (1 - 1e-8) <= (dur / expected_dur) < 1.5
- # outcome is not typed
return result.unwrap()
From cfbbcea902fd8cd190f4dbb0b03d54a1f8bfb0c8 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 25 Oct 2023 15:21:52 +0200
Subject: [PATCH 26/35] add --show-fixes to ruff in pre-commit, test
show_diff_on_failure
---
.pre-commit-config.yaml | 2 ++
1 file changed, 2 insertions(+)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ca4811ed2c..deda376520 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,6 +19,7 @@ repos:
- id: ruff
types: [file]
types_or: [python, pyi, toml]
+ args: ["--show-fixes"]
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
hooks:
@@ -30,3 +31,4 @@ ci:
autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate"
autoupdate_schedule: weekly
submodules: false
+ show_diff_on_failure: true
From 617038449537e7ed6bfd0f377ee6b360fd49a910 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 25 Oct 2023 15:36:08 +0200
Subject: [PATCH 27/35] autofix from pre-commit
---
.pre-commit-config.yaml | 1 -
trio/_core/_tests/test_multierror.py | 4 ++--
2 files changed, 2 insertions(+), 3 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index deda376520..cfa4347f63 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -31,4 +31,3 @@ ci:
autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate"
autoupdate_schedule: weekly
submodules: false
- show_diff_on_failure: true
diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py
index 589fd0eea6..83a0489653 100644
--- a/trio/_core/_tests/test_multierror.py
+++ b/trio/_core/_tests/test_multierror.py
@@ -10,7 +10,7 @@
from pathlib import Path
from traceback import extract_tb, print_exception
from types import TracebackType
-from typing import Callable, List, NoReturn
+from typing import Callable, NoReturn
import pytest
@@ -400,7 +400,7 @@ def simple_filter(exc: BaseException) -> Exception | RuntimeError:
gc.garbage.clear()
-def assert_match_in_seq(pattern_list: List[str], string: str) -> None:
+def assert_match_in_seq(pattern_list: list[str], string: str) -> None:
offset = 0
print("looking for pattern matches...")
for pattern in pattern_list:
From 90d6cab092cbfd3f456fde373abbc623d4733a8c Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 25 Oct 2023 15:57:28 +0200
Subject: [PATCH 28/35] fix RTD build error
---
docs/source/conf.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 06b661f126..c56ce12925 100755
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -71,6 +71,8 @@
# "types.FrameType" is more helpful than just "frame"
"FrameType": "types.FrameType",
"Context": "OpenSSL.SSL.Context",
+ # SSLListener.accept's return type is seen as trio._ssl.SSLStream
+ "SSLStream": "trio.SSLStream",
}
From 95fa49e762773f82ac422fd34534a4621d591623 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 25 Oct 2023 16:05:31 +0200
Subject: [PATCH 29/35] fix most codecov issues
---
.coveragerc | 2 +-
trio/_core/_tests/test_guest_mode.py | 2 +-
trio/_core/_tests/test_multierror.py | 2 +-
trio/_highlevel_serve_listeners.py | 2 +-
trio/_tests/test_ssl.py | 6 +++---
5 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/.coveragerc b/.coveragerc
index 4911012653..604c14775f 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -20,7 +20,7 @@ skip_covered = True
exclude_lines =
pragma: no cover
abc.abstractmethod
- if TYPE_CHECKING:
+ if TYPE_CHECKING.*:
if _t.TYPE_CHECKING:
if t.TYPE_CHECKING:
@overload
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index af80fd8218..175f687125 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -483,7 +483,7 @@ async def trio_main() -> str:
aio_task.cancel()
return "trio-main-done"
- raise AssertionError("should never be reached")
+ raise AssertionError("should never be reached") # pragma: no cov
async def aio_pingpong(
from_trio: asyncio.Queue[int], to_trio: MemorySendChannel[int]
diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py
index 83a0489653..91e5fc73b6 100644
--- a/trio/_core/_tests/test_multierror.py
+++ b/trio/_core/_tests/test_multierror.py
@@ -69,7 +69,7 @@ def get_exc(raiser: Callable[[], NoReturn]) -> BaseException:
raiser()
except Exception as exc:
return exc
- raise AssertionError("raiser should always raise")
+ raise AssertionError("raiser should always raise") # pragma: no cov
def get_tb(raiser: Callable[[], NoReturn]) -> TracebackType | None:
diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py
index d949cb8f87..0940e102a5 100644
--- a/trio/_highlevel_serve_listeners.py
+++ b/trio/_highlevel_serve_listeners.py
@@ -142,4 +142,4 @@ async def serve_listeners(
# error or whatever.
task_status.started(listeners)
- raise AssertionError("_serve_one_listener should never complete")
+ raise AssertionError("_serve_one_listener should never complete") # pragma: no cov
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index e0df547b39..8e904679e9 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -778,13 +778,13 @@ async def wait_send_all_might_not_block(self) -> None:
# define methods that are abstract in Stream
async def aclose(self) -> None:
- raise AssertionError("Should not get called")
+ raise AssertionError("Should not get called") # pragma: no cov
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
- raise AssertionError("Should not get called")
+ raise AssertionError("Should not get called") # pragma: no cov
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
- raise AssertionError("Should not get called")
+ raise AssertionError("Should not get called") # pragma: no cov
ctx = ssl.create_default_context()
s = SSLStream(NotAStream(), ctx, server_hostname="x")
From 027279a806ec248cdbe44eccb269f68782477d2d Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 25 Oct 2023 16:41:37 +0200
Subject: [PATCH 30/35] pragma: no cov -> pragma: no cover
---
trio/_core/_tests/test_guest_mode.py | 2 +-
trio/_core/_tests/test_multierror.py | 2 +-
trio/_highlevel_serve_listeners.py | 4 +++-
trio/_tests/test_ssl.py | 6 +++---
4 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py
index 175f687125..a648832b4c 100644
--- a/trio/_core/_tests/test_guest_mode.py
+++ b/trio/_core/_tests/test_guest_mode.py
@@ -483,7 +483,7 @@ async def trio_main() -> str:
aio_task.cancel()
return "trio-main-done"
- raise AssertionError("should never be reached") # pragma: no cov
+ raise AssertionError("should never be reached") # pragma: no cover
async def aio_pingpong(
from_trio: asyncio.Queue[int], to_trio: MemorySendChannel[int]
diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py
index 91e5fc73b6..09f9b6a271 100644
--- a/trio/_core/_tests/test_multierror.py
+++ b/trio/_core/_tests/test_multierror.py
@@ -69,7 +69,7 @@ def get_exc(raiser: Callable[[], NoReturn]) -> BaseException:
raiser()
except Exception as exc:
return exc
- raise AssertionError("raiser should always raise") # pragma: no cov
+ raise AssertionError("raiser should always raise") # pragma: no cover
def get_tb(raiser: Callable[[], NoReturn]) -> TracebackType | None:
diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py
index 0940e102a5..ec5a0efb3c 100644
--- a/trio/_highlevel_serve_listeners.py
+++ b/trio/_highlevel_serve_listeners.py
@@ -142,4 +142,6 @@ async def serve_listeners(
# error or whatever.
task_status.started(listeners)
- raise AssertionError("_serve_one_listener should never complete") # pragma: no cov
+ raise AssertionError(
+ "_serve_one_listener should never complete"
+ ) # pragma: no cover
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index 8e904679e9..cd404a1c97 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -778,13 +778,13 @@ async def wait_send_all_might_not_block(self) -> None:
# define methods that are abstract in Stream
async def aclose(self) -> None:
- raise AssertionError("Should not get called") # pragma: no cov
+ raise AssertionError("Should not get called") # pragma: no cover
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
- raise AssertionError("Should not get called") # pragma: no cov
+ raise AssertionError("Should not get called") # pragma: no cover
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
- raise AssertionError("Should not get called") # pragma: no cov
+ raise AssertionError("Should not get called") # pragma: no cover
ctx = ssl.create_default_context()
s = SSLStream(NotAStream(), ctx, server_hostname="x")
From fc0a6d3b5bd360ac907dcb2bc8254d73b98a3350 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 29 Oct 2023 18:54:06 +0000
Subject: [PATCH 31/35] [pre-commit.ci] auto fixes from pre-commit.com hooks
---
trio/_tests/test_sync.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py
index 3f9ca88650..9179c8a5ae 100644
--- a/trio/_tests/test_sync.py
+++ b/trio/_tests/test_sync.py
@@ -247,7 +247,7 @@ async def test_Semaphore_bounded() -> None:
@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
async def test_Lock_and_StrictFIFOLock(
- lockcls: type[Lock] | type[StrictFIFOLock],
+ lockcls: type[Lock | StrictFIFOLock],
) -> None:
l = lockcls() # noqa
assert not l.locked()
From 81f8bf0b8f0aa2c43844579a6895fb9759c8c27a Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 30 Oct 2023 12:31:21 +0100
Subject: [PATCH 32/35] minor changes after review comments
---
trio/_core/_tests/test_instrumentation.py | 5 ++++-
trio/_core/_tests/test_ki.py | 1 +
trio/_core/_tests/tutil.py | 6 ++++--
trio/_highlevel_ssl_helpers.py | 4 ++--
trio/_ssl.py | 2 +-
trio/_subprocess_platform/waitid.py | 6 ------
trio/_tests/test_exports.py | 3 +--
trio/_tests/test_highlevel_serve_listeners.py | 17 ++++++++---------
trio/_tests/test_subprocess.py | 10 ++++++++--
9 files changed, 29 insertions(+), 25 deletions(-)
diff --git a/trio/_core/_tests/test_instrumentation.py b/trio/_core/_tests/test_instrumentation.py
index 1c35a17ee6..f743f2b3d4 100644
--- a/trio/_core/_tests/test_instrumentation.py
+++ b/trio/_core/_tests/test_instrumentation.py
@@ -219,7 +219,10 @@ async def main() -> None:
# Changing the set of hooks implemented by an instrument after
# it's installed doesn't make them start being called right away
- instrument.before_task_step = record.append # type: ignore[assignment]
+ instrument.before_task_step = ( # type: ignore[method-assign]
+ record.append # type: ignore[assignment] # append is pos-only
+ )
+
await _core.checkpoint()
await _core.checkpoint()
assert len(record) == 0
diff --git a/trio/_core/_tests/test_ki.py b/trio/_core/_tests/test_ki.py
index bc4a0192af..cd98bc9bca 100644
--- a/trio/_core/_tests/test_ki.py
+++ b/trio/_core/_tests/test_ki.py
@@ -145,6 +145,7 @@ def protected_manager() -> Iterator[None]:
raise KeyError
+# the async_generator package isn't typed, hence all the type: ignores
@pytest.mark.skipif(async_generator is None, reason="async_generator not installed")
async def test_async_generator_agen_protection() -> None:
@_core.enable_ki_protection
diff --git a/trio/_core/_tests/tutil.py b/trio/_core/_tests/tutil.py
index 1d49d8b262..6ed9b5fe14 100644
--- a/trio/_core/_tests/tutil.py
+++ b/trio/_core/_tests/tutil.py
@@ -98,8 +98,10 @@ def restore_unraisablehook() -> Generator[None, None, None]:
sys.unraisablehook = prev
-# template is like:
-# [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3]
+# Used to check sequences that might have some elements out of order.
+# Example usage:
+# The sequences [1, 2.1, 2.2, 3] and [1, 2.2, 2.1, 3] are both
+# matched by the template [1, {2.1, 2.2}, 3]
def check_sequence_matches(seq: Sequence[T], template: Iterable[T | set[T]]) -> None:
i = 0
for pattern in template:
diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py
index e70a8db47d..3215fcf969 100644
--- a/trio/_highlevel_ssl_helpers.py
+++ b/trio/_highlevel_ssl_helpers.py
@@ -22,7 +22,7 @@
# So... let's punt on that for now. Hopefully we'll be getting a new Python
# TLS API soon and can revisit this then.
async def open_ssl_over_tcp_stream(
- host: str,
+ host: str | bytes,
port: int,
*,
https_compatible: bool = False,
@@ -40,7 +40,7 @@ async def open_ssl_over_tcp_stream(
data.
Args:
- host (str): The host to connect to. We require the server
+ host (bytes or str): The host to connect to. We require the server
to have a TLS certificate valid for this hostname.
port (int): The port to connect to.
https_compatible (bool): Set this to True if you're connecting to a web
diff --git a/trio/_ssl.py b/trio/_ssl.py
index 4ef8a15721..d2e1c73da5 100644
--- a/trio/_ssl.py
+++ b/trio/_ssl.py
@@ -345,7 +345,7 @@ def __init__(
transport_stream: T_Stream,
ssl_context: _stdlib_ssl.SSLContext,
*,
- server_hostname: str | None = None,
+ server_hostname: str | bytes | None = None,
server_side: bool = False,
https_compatible: bool = False,
) -> None:
diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py
index 03a464d6a9..756741218f 100644
--- a/trio/_subprocess_platform/waitid.py
+++ b/trio/_subprocess_platform/waitid.py
@@ -8,12 +8,6 @@
from .._sync import CapacityLimiter, Event
from .._threads import to_thread_run_sync
-if TYPE_CHECKING:
-
- def sync_wait_reapable(pid: int) -> None:
- ...
-
-
assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING
try:
diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py
index ec10edf4eb..7b38137887 100644
--- a/trio/_tests/test_exports.py
+++ b/trio/_tests/test_exports.py
@@ -2,6 +2,7 @@
import __future__ # Regular import, not special!
+import enum
import functools
import importlib
import inspect
@@ -446,8 +447,6 @@ def lookup_symbol(symbol: str) -> dict[str, str]:
extra = {e for e in extra if not e.endswith("AttrsAttributes__")}
assert len(extra) == before - 1
- import enum
-
# mypy does not see these attributes in Enum subclasses
if (
tool == "mypy"
diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py
index 060ba46ed7..75cadfd8aa 100644
--- a/trio/_tests/test_highlevel_serve_listeners.py
+++ b/trio/_tests/test_highlevel_serve_listeners.py
@@ -21,25 +21,26 @@
# types are somewhat tentative - I just bruteforced them until I got something that didn't
# give errors
-TypeThing = StapledStream[MemorySendStream, MemoryReceiveStream]
+StapledMemoryStream = StapledStream[MemorySendStream, MemoryReceiveStream]
@attr.s(hash=False, eq=False)
-class MemoryListener(trio.abc.Listener[TypeThing]):
+class MemoryListener(trio.abc.Listener[StapledMemoryStream]):
closed: bool = attr.ib(default=False)
accepted_streams: list[trio.abc.Stream] = attr.ib(factory=list)
queued_streams: tuple[
- MemorySendChannel[TypeThing], MemoryReceiveChannel[TypeThing]
- ] = attr.ib(factory=(lambda: trio.open_memory_channel[TypeThing](1)))
+ MemorySendChannel[StapledMemoryStream],
+ MemoryReceiveChannel[StapledMemoryStream],
+ ] = attr.ib(factory=(lambda: trio.open_memory_channel[StapledMemoryStream](1)))
accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None)
- async def connect(self) -> StapledStream[MemorySendStream, MemoryReceiveStream]:
+ async def connect(self) -> StapledMemoryStream:
assert not self.closed
client, server = memory_stream_pair()
await self.queued_streams[0].send(server)
return client
- async def accept(self) -> TypeThing:
+ async def accept(self) -> StapledMemoryStream:
await trio.lowlevel.checkpoint()
assert not self.closed
if self.accept_hook is not None:
@@ -63,9 +64,7 @@ def close_hook() -> None:
assert trio.current_effective_deadline() == float("-inf")
record.append("closed")
- async def handler(
- stream: StapledStream[MemorySendStream, MemoryReceiveStream]
- ) -> None:
+ async def handler(stream: StapledMemoryStream) -> None:
await stream.send_all(b"123")
assert await stream.receive_some(10) == b"456"
stream.send_stream.close_hook = close_hook
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index 7a5fafdddb..eb8bfd1c43 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -514,12 +514,18 @@ def test_waitid_eintr() -> None:
# ourselves) but the test works on all waitid platforms.
from .._subprocess_platform import wait_child_exiting
- if TYPE_CHECKING and sys.platform == "win32":
+ if TYPE_CHECKING and (sys.platform == "win32" or sys.platform == "darwin"):
return
if not wait_child_exiting.__module__.endswith("waitid"):
pytest.skip("waitid only")
- from .._subprocess_platform.waitid import sync_wait_reapable
+
+ # despite the TYPE_CHECKING early return silencing warnings about signal.SIGALRM etc
+ # this import is still checked on win32&darwin and raises [attr-defined].
+ # Linux doesn't raise [attr-defined] though, so we need [unused-ignore]
+ from .._subprocess_platform.waitid import (
+ sync_wait_reapable, # type: ignore[attr-defined, unused-ignore]
+ )
got_alarm = False
sleeper = subprocess.Popen(["sleep", "3600"])
From 3935a8269763cad257ba29a0759974be53f2c7ee Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 30 Oct 2023 12:38:50 +0100
Subject: [PATCH 33/35] fix incorrect suppression found by enabling
typechecking (!), move type: ignore to the correct line after incorrect
autofix by black
---
trio/_tests/test_ssl.py | 6 ++----
trio/_tests/test_subprocess.py | 6 +++---
2 files changed, 5 insertions(+), 7 deletions(-)
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index cd404a1c97..e0b76d812a 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -5,7 +5,7 @@
import ssl
import sys
import threading
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import asynccontextmanager, contextmanager, suppress
from functools import partial
from ssl import SSLContext
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, NoReturn
@@ -116,10 +116,8 @@ def ssl_echo_serve_sync(
# respond in kind but it's legal for them to have already
# gone away.
exceptions = (BrokenPipeError, ssl.SSLZeroReturnError)
- try:
+ with suppress(*exceptions):
wrapped.unwrap()
- except exceptions:
- pass
return
wrapped.sendall(data)
# This is an obscure workaround for an openssl bug. In server mode, in
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index eb8bfd1c43..c901f6f29e 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -231,7 +231,7 @@ async def check_output(stream: Stream, expected: bytes) -> None:
nursery.start_soon(check_output, proc.stderr, msg[::-1])
assert not nursery.cancel_scope.cancelled_caught
- assert 0 == await proc.wait()
+ assert await proc.wait() == 0
@background_process_param
@@ -523,8 +523,8 @@ def test_waitid_eintr() -> None:
# despite the TYPE_CHECKING early return silencing warnings about signal.SIGALRM etc
# this import is still checked on win32&darwin and raises [attr-defined].
# Linux doesn't raise [attr-defined] though, so we need [unused-ignore]
- from .._subprocess_platform.waitid import (
- sync_wait_reapable, # type: ignore[attr-defined, unused-ignore]
+ from .._subprocess_platform.waitid import ( # type: ignore[attr-defined, unused-ignore]
+ sync_wait_reapable,
)
got_alarm = False
From 4d25e75f5a7f393e0e75741257381d7336fdf496 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 30 Oct 2023 13:08:33 +0100
Subject: [PATCH 34/35] mark lines with # pragma: no cover
---
trio/_tests/test_tracing.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/trio/_tests/test_tracing.py b/trio/_tests/test_tracing.py
index 1b9ea02b5f..5cf758c6b6 100644
--- a/trio/_tests/test_tracing.py
+++ b/trio/_tests/test_tracing.py
@@ -22,9 +22,9 @@ async def coro2_async_gen(event: trio.Event) -> AsyncGenerator[None, None]:
await trio.lowlevel.checkpoint()
yield
await coro1(event)
- yield
- await trio.lowlevel.checkpoint()
- yield
+ yield # pragma: no cover
+ await trio.lowlevel.checkpoint() # pragma: no cover
+ yield # pragma: no cover
async def coro3_async_gen(event: trio.Event) -> None:
From 43c7d6da9f56e9de33ce33bf463ec8bba4be799c Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 30 Oct 2023 13:14:24 +0100
Subject: [PATCH 35/35] remove redundant tuple
---
trio/_tests/test_ssl.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py
index e0b76d812a..13decd5c72 100644
--- a/trio/_tests/test_ssl.py
+++ b/trio/_tests/test_ssl.py
@@ -115,8 +115,7 @@ def ssl_echo_serve_sync(
# other side has initiated a graceful shutdown; we try to
# respond in kind but it's legal for them to have already
# gone away.
- exceptions = (BrokenPipeError, ssl.SSLZeroReturnError)
- with suppress(*exceptions):
+ with suppress(BrokenPipeError, ssl.SSLZeroReturnError):
wrapped.unwrap()
return
wrapped.sendall(data)