From 3d7ecaccfad6a9d5194f4adbe8fc736c5e907668 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 15 Oct 2023 15:14:54 +0200 Subject: [PATCH 01/35] manual annotations --- pyproject.toml | 27 ++++--- trio/_core/__init__.py | 2 +- trio/_core/_tests/test_guest_mode.py | 24 +++--- trio/_core/_tests/test_instrumentation.py | 97 ++++++++++++----------- trio/_core/_tests/test_ki.py | 76 +++++++++--------- trio/_core/_tests/test_local.py | 51 ++++++------ trio/_core/_tests/test_thread_cache.py | 8 +- trio/_core/_tests/test_unbounded_queue.py | 32 ++++---- trio/lowlevel.py | 1 + 9 files changed, 175 insertions(+), 143 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 73b110ebac..e12f95c309 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,25 +80,34 @@ disallow_untyped_defs = true check_untyped_defs = true disallow_untyped_calls = false -# files not yet fully typed + +# partially typed tests [[tool.mypy.overrides]] module = [ -# internal -"trio/_windows_pipes", - -# tests -"trio/testing/_fake_net", "trio/_core/_tests/test_guest_mode", -"trio/_core/_tests/test_instrumentation", "trio/_core/_tests/test_ki", -"trio/_core/_tests/test_local", "trio/_core/_tests/test_mock_clock", "trio/_core/_tests/test_multierror", "trio/_core/_tests/test_multierror_scripts/ipython_custom_exc", "trio/_core/_tests/test_multierror_scripts/simple_excepthook", "trio/_core/_tests/test_parking_lot", "trio/_core/_tests/test_thread_cache", -"trio/_core/_tests/test_unbounded_queue", +] +check_untyped_defs = true +disallow_any_decorated = false +disallow_any_generics = false +disallow_any_unimported = false +disallow_incomplete_defs = false +disallow_untyped_defs = false + +# files not yet fully typed +[[tool.mypy.overrides]] +module = [ +# internal +"trio/_windows_pipes", + +# tests +"trio/testing/_fake_net", "trio/_tests/test_exports", "trio/_tests/test_file_io", "trio/_tests/test_highlevel_generic", diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index b9bd0d8cc4..71f5f17eb2 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -18,7 +18,7 @@ WouldBlock, ) from ._ki import currently_ki_protected, disable_ki_protection, enable_ki_protection -from ._local import RunVar +from ._local import RunVar, RunVarToken from ._mock_clock import MockClock from ._parking_lot import ParkingLot, ParkingLotStatistics diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index 80180be805..21980b9b72 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import contextvars import queue @@ -10,8 +12,10 @@ import warnings from functools import partial from math import inf +from typing import Callable import pytest +from outcome import Outcome import trio import trio.testing @@ -27,11 +31,11 @@ # - final result is returned # - any unhandled exceptions cause an immediate crash def trivial_guest_run(trio_fn, *, in_host_after_start=None, **start_guest_run_kwargs): - todo = queue.Queue() + todo: queue.Queue[tuple[str, Outcome]] = queue.Queue() host_thread = threading.current_thread() - def run_sync_soon_threadsafe(fn): + def run_sync_soon_threadsafe(fn: Callable): nonlocal todo if host_thread is threading.current_thread(): # pragma: no cover crash = partial( @@ -40,7 +44,7 @@ def run_sync_soon_threadsafe(fn): todo.put(("run", crash)) todo.put(("run", fn)) - def run_sync_soon_not_threadsafe(fn): + def run_sync_soon_not_threadsafe(fn: Callable): nonlocal todo if host_thread is not threading.current_thread(): # pragma: no cover crash = partial( @@ -49,7 +53,7 @@ def run_sync_soon_not_threadsafe(fn): todo.put(("run", crash)) todo.put(("run", fn)) - def done_callback(outcome): + def done_callback(outcome: Outcome): nonlocal todo todo.put(("unwrap", outcome)) @@ -322,18 +326,18 @@ async def get_woken_by_host_deadline(watb_cscope): # 'sit_in_wait_all_tasks_blocked', we want the test to # actually end. So in after_io_wait we schedule a second host # call to tear things down. - class InstrumentHelper: - def __init__(self): + class InstrumentHelper(trio._abc.Instrument): + def __init__(self) -> None: self.primed = False - def before_io_wait(self, timeout): + def before_io_wait(self, timeout: float) -> None: print(f"before_io_wait({timeout})") if timeout == 9999: # pragma: no branch assert not self.primed in_host(lambda: set_deadline(cscope, 1e9)) self.primed = True - def after_io_wait(self, timeout): + def after_io_wait(self, timeout: float) -> None: if self.primed: # pragma: no branch print("instrument triggered") in_host(lambda: cscope.cancel()) @@ -429,8 +433,8 @@ def test_guest_mode_on_asyncio(): async def trio_main(): print("trio_main!") - to_trio, from_aio = trio.open_memory_channel(float("inf")) - from_trio = asyncio.Queue() + to_trio, from_aio = trio.open_memory_channel[int](float("inf")) + from_trio = asyncio.Queue[int]() aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio)) diff --git a/trio/_core/_tests/test_instrumentation.py b/trio/_core/_tests/test_instrumentation.py index 498a3eb272..8b103dbfdc 100644 --- a/trio/_core/_tests/test_instrumentation.py +++ b/trio/_core/_tests/test_instrumentation.py @@ -1,32 +1,37 @@ +from __future__ import annotations + +from typing import Container, Iterable, NoReturn + import attr import pytest from ... import _abc, _core +from ...lowlevel import Task from .tutil import check_sequence_matches @attr.s(eq=False, hash=False) -class TaskRecorder: - record = attr.ib(factory=list) +class TaskRecorder(_abc.Instrument): + record: list[tuple[str, Task | None]] = attr.ib(factory=list) - def before_run(self): - self.record.append(("before_run",)) + def before_run(self) -> None: + self.record.append(("before_run", None)) - def task_scheduled(self, task): + def task_scheduled(self, task: Task) -> None: self.record.append(("schedule", task)) - def before_task_step(self, task): + def before_task_step(self, task: Task) -> None: assert task is _core.current_task() self.record.append(("before", task)) - def after_task_step(self, task): + def after_task_step(self, task: Task) -> None: assert task is _core.current_task() self.record.append(("after", task)) - def after_run(self): - self.record.append(("after_run",)) + def after_run(self) -> None: + self.record.append(("after_run", None)) - def filter_tasks(self, tasks): + def filter_tasks(self, tasks: Container[Task]) -> Iterable[tuple[str, Task | None]]: for item in self.record: if item[0] in ("schedule", "before", "after") and item[1] in tasks: yield item @@ -34,7 +39,7 @@ def filter_tasks(self, tasks): yield item -def test_instruments(recwarn): +def test_instruments(recwarn: object) -> None: r1 = TaskRecorder() r2 = TaskRecorder() r3 = TaskRecorder() @@ -44,7 +49,7 @@ def test_instruments(recwarn): # We use a child task for this, because the main task does some extra # bookkeeping stuff that can leak into the instrument results, and we # don't want to deal with it. - async def task_fn(): + async def task_fn() -> None: nonlocal task task = _core.current_task() @@ -60,7 +65,7 @@ async def task_fn(): for _ in range(1): await _core.checkpoint() - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(task_fn) @@ -75,21 +80,22 @@ async def main(): + [("before", task), ("after", task), ("after_run",)] ) assert r1.record == r2.record + r3.record + assert task is not None assert list(r1.filter_tasks([task])) == expected -def test_instruments_interleave(): +def test_instruments_interleave() -> None: tasks = {} - async def two_step1(): + async def two_step1() -> None: tasks["t1"] = _core.current_task() await _core.checkpoint() - async def two_step2(): + async def two_step2() -> None: tasks["t2"] = _core.current_task() await _core.checkpoint() - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(two_step1) nursery.start_soon(two_step2) @@ -121,46 +127,46 @@ async def main(): check_sequence_matches(list(r.filter_tasks(tasks.values())), expected) -def test_null_instrument(): +def test_null_instrument() -> None: # undefined instrument methods are skipped - class NullInstrument: - def something_unrelated(self): + class NullInstrument(_abc.Instrument): + def something_unrelated(self) -> None: pass # pragma: no cover - async def main(): + async def main() -> None: await _core.checkpoint() _core.run(main, instruments=[NullInstrument()]) -def test_instrument_before_after_run(): +def test_instrument_before_after_run() -> None: record = [] - class BeforeAfterRun: - def before_run(self): + class BeforeAfterRun(_abc.Instrument): + def before_run(self) -> None: record.append("before_run") - def after_run(self): + def after_run(self) -> None: record.append("after_run") - async def main(): + async def main() -> None: pass _core.run(main, instruments=[BeforeAfterRun()]) assert record == ["before_run", "after_run"] -def test_instrument_task_spawn_exit(): +def test_instrument_task_spawn_exit() -> None: record = [] - class SpawnExitRecorder: - def task_spawned(self, task): + class SpawnExitRecorder(_abc.Instrument): + def task_spawned(self, task: Task) -> None: record.append(("spawned", task)) - def task_exited(self, task): + def task_exited(self, task: Task) -> None: record.append(("exited", task)) - async def main(): + async def main() -> Task: return _core.current_task() main_task = _core.run(main, instruments=[SpawnExitRecorder()]) @@ -170,20 +176,20 @@ async def main(): # This test also tests having a crash before the initial task is even spawned, # which is very difficult to handle. -def test_instruments_crash(caplog): +def test_instruments_crash(caplog: pytest.LogCaptureFixture) -> None: record = [] - class BrokenInstrument: - def task_scheduled(self, task): + class BrokenInstrument(_abc.Instrument): + def task_scheduled(self, task: Task) -> NoReturn: record.append("scheduled") raise ValueError("oops") - def close(self): + def close(self) -> None: # Shouldn't be called -- tests that the instrument disabling logic # works right. record.append("closed") # pragma: no cover - async def main(): + async def main() -> Task: record.append("main ran") return _core.current_task() @@ -195,24 +201,25 @@ async def main(): assert ("after", main_task) in r.record assert ("after_run",) in r.record # And we got a log message + assert caplog.records[0].exc_info is not None exc_type, exc_value, exc_traceback = caplog.records[0].exc_info assert exc_type is ValueError assert str(exc_value) == "oops" assert "Instrument has been disabled" in caplog.records[0].message -def test_instruments_monkeypatch(): +def test_instruments_monkeypatch() -> None: class NullInstrument(_abc.Instrument): pass instrument = NullInstrument() - async def main(): - record = [] + async def main() -> None: + record: list[Task] = [] # Changing the set of hooks implemented by an instrument after # it's installed doesn't make them start being called right away - instrument.before_task_step = record.append + instrument.before_task_step = record.append # type: ignore[assignment] await _core.checkpoint() await _core.checkpoint() assert len(record) == 0 @@ -233,16 +240,16 @@ async def main(): _core.run(main, instruments=[instrument]) -def test_instrument_that_raises_on_getattr(): - class EvilInstrument: - def task_exited(self, task): +def test_instrument_that_raises_on_getattr() -> None: + class EvilInstrument(_abc.Instrument): + def task_exited(self, task: Task) -> NoReturn: assert False # pragma: no cover @property - def after_run(self): + def after_run(self) -> NoReturn: raise ValueError("oops") - async def main(): + async def main() -> None: with pytest.raises(ValueError): _core.add_instrument(EvilInstrument()) diff --git a/trio/_core/_tests/test_ki.py b/trio/_core/_tests/test_ki.py index dc9f2f51e5..0fda688194 100644 --- a/trio/_core/_tests/test_ki.py +++ b/trio/_core/_tests/test_ki.py @@ -15,6 +15,7 @@ async_generator = yield_ = None from ... import _core +from ..._abc import Instrument from ..._timeouts import sleep from ..._util import signal_raise from ...testing import wait_all_tasks_blocked @@ -283,33 +284,33 @@ async def raiser(name, record): # simulated control-C during raiser, which is *unprotected* print("check 1") - record = set() + record_set: set[str] = set() async def check_unprotected_kill(): async with _core.open_nursery() as nursery: - nursery.start_soon(sleeper, "s1", record) - nursery.start_soon(sleeper, "s2", record) - nursery.start_soon(raiser, "r1", record) + nursery.start_soon(sleeper, "s1", record_set) + nursery.start_soon(sleeper, "s2", record_set) + nursery.start_soon(raiser, "r1", record_set) with pytest.raises(KeyboardInterrupt): _core.run(check_unprotected_kill) - assert record == {"s1 ok", "s2 ok", "r1 raise ok"} + assert record_set == {"s1 ok", "s2 ok", "r1 raise ok"} # simulated control-C during raiser, which is *protected*, so the KI gets # delivered to the main task instead print("check 2") - record = set() + record_set = set() async def check_protected_kill(): async with _core.open_nursery() as nursery: - nursery.start_soon(sleeper, "s1", record) - nursery.start_soon(sleeper, "s2", record) - nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record) + nursery.start_soon(sleeper, "s1", record_set) + nursery.start_soon(sleeper, "s2", record_set) + nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record_set) # __aexit__ blocks, and then receives the KI with pytest.raises(KeyboardInterrupt): _core.run(check_protected_kill) - assert record == {"s1 ok", "s2 ok", "r1 cancel ok"} + assert record_set == {"s1 ok", "s2 ok", "r1 cancel ok"} # kill at last moment still raises (run_sync_soon until it raises an # error, then kill) @@ -335,33 +336,33 @@ def kill_during_shutdown(): # KI arrives very early, before main is even spawned print("check 4") - class InstrumentOfDeath: - def before_run(self): + class InstrumentOfDeath(Instrument): + def before_run(self) -> None: ki_self() - async def main(): + async def main_1(): await _core.checkpoint() with pytest.raises(KeyboardInterrupt): - _core.run(main, instruments=[InstrumentOfDeath()]) + _core.run(main_1, instruments=[InstrumentOfDeath()]) # checkpoint_if_cancelled notices pending KI print("check 5") @_core.enable_ki_protection - async def main(): + async def main_2(): assert _core.currently_ki_protected() ki_self() with pytest.raises(KeyboardInterrupt): await _core.checkpoint_if_cancelled() - _core.run(main) + _core.run(main_2) # KI arrives while main task is not abortable, b/c already scheduled print("check 6") @_core.enable_ki_protection - async def main(): + async def main_3(): assert _core.currently_ki_protected() ki_self() await _core.cancel_shielded_checkpoint() @@ -370,13 +371,13 @@ async def main(): with pytest.raises(KeyboardInterrupt): await _core.checkpoint() - _core.run(main) + _core.run(main_3) # KI arrives while main task is not abortable, b/c refuses to be aborted print("check 7") @_core.enable_ki_protection - async def main(): + async def main_4(): assert _core.currently_ki_protected() ki_self() task = _core.current_task() @@ -389,13 +390,13 @@ def abort(_: RaiseCancelT) -> Abort: with pytest.raises(KeyboardInterrupt): await _core.checkpoint() - _core.run(main) + _core.run(main_4) # KI delivered via slow abort print("check 8") @_core.enable_ki_protection - async def main(): + async def main_5(): assert _core.currently_ki_protected() ki_self() task = _core.current_task() @@ -409,7 +410,7 @@ def abort(raise_cancel: RaiseCancelT) -> Abort: assert await _core.wait_task_rescheduled(abort) await _core.checkpoint() - _core.run(main) + _core.run(main_5) # KI arrives just before main task exits, so the run_sync_soon machinery # is still functioning and will accept the callback to deliver the KI, but @@ -418,42 +419,42 @@ def abort(raise_cancel: RaiseCancelT) -> Abort: print("check 9") @_core.enable_ki_protection - async def main(): + async def main_6(): ki_self() with pytest.raises(KeyboardInterrupt): - _core.run(main) + _core.run(main_6) print("check 10") # KI in unprotected code, with # restrict_keyboard_interrupt_to_checkpoints=True - record = [] + record_list = [] - async def main(): + async def main_7(): # We're not KI protected... assert not _core.currently_ki_protected() ki_self() # ...but even after the KI, we keep running uninterrupted... - record.append("ok") + record_list.append("ok") # ...until we hit a checkpoint: with pytest.raises(KeyboardInterrupt): await sleep(10) - _core.run(main, restrict_keyboard_interrupt_to_checkpoints=True) - assert record == ["ok"] - record = [] + _core.run(main_7, restrict_keyboard_interrupt_to_checkpoints=True) + assert record_list == ["ok"] + record_list = [] # Exact same code raises KI early if we leave off the argument, doesn't # even reach the record.append call: with pytest.raises(KeyboardInterrupt): - _core.run(main) - assert record == [] + _core.run(main_7) + assert record_list == [] # KI arrives while main task is inside a cancelled cancellation scope # the KeyboardInterrupt should take priority print("check 11") @_core.enable_ki_protection - async def main(): + async def main_8(): assert _core.currently_ki_protected() with _core.CancelScope() as cancel_scope: cancel_scope.cancel() @@ -465,7 +466,7 @@ async def main(): with pytest.raises(_core.Cancelled): await _core.checkpoint() - _core.run(main) + _core.run(main_8) def test_ki_is_good_neighbor(): @@ -488,16 +489,17 @@ async def main(): # Regression test for #461 +# don't know if _active not being visible is a problem def test_ki_with_broken_threads(): thread = threading.main_thread() # scary! - original = threading._active[thread.ident] + original = threading._active[thread.ident] # type: ignore[attr-defined] # put this in a try finally so we don't have a chance of cascading a # breakage down to everything else try: - del threading._active[thread.ident] + del threading._active[thread.ident] # type: ignore[attr-defined] @_core.enable_ki_protection async def inner(): @@ -505,4 +507,4 @@ async def inner(): _core.run(inner) finally: - threading._active[thread.ident] = original + threading._active[thread.ident] = original # type: ignore[attr-defined] diff --git a/trio/_core/_tests/test_local.py b/trio/_core/_tests/test_local.py index d36be0479e..5fdf54b13c 100644 --- a/trio/_core/_tests/test_local.py +++ b/trio/_core/_tests/test_local.py @@ -1,16 +1,19 @@ import pytest +from trio import run +from trio.lowlevel import RunVar, RunVarToken + from ... import _core # scary runvar tests -def test_runvar_smoketest(): - t1 = _core.RunVar("test1") - t2 = _core.RunVar("test2", default="catfish") +def test_runvar_smoketest() -> None: + t1 = RunVar[str]("test1") + t2 = RunVar[str]("test2", default="catfish") assert repr(t1) == "" - async def first_check(): + async def first_check() -> None: with pytest.raises(LookupError): t1.get() @@ -23,28 +26,28 @@ async def first_check(): assert t2.get() == "goldfish" assert t2.get(default="tuna") == "goldfish" - async def second_check(): + async def second_check() -> None: with pytest.raises(LookupError): t1.get() assert t2.get() == "catfish" - _core.run(first_check) - _core.run(second_check) + run(first_check) + run(second_check) -def test_runvar_resetting(): - t1 = _core.RunVar("test1") - t2 = _core.RunVar("test2", default="dogfish") - t3 = _core.RunVar("test3") +def test_runvar_resetting() -> None: + t1 = RunVar[str]("test1") + t2 = RunVar[str]("test2", default="dogfish") + t3 = RunVar[str]("test3") - async def reset_check(): + async def reset_check() -> None: token = t1.set("moonfish") assert t1.get() == "moonfish" t1.reset(token) with pytest.raises(TypeError): - t1.reset(None) + t1.reset(None) # type: ignore[arg-type] with pytest.raises(LookupError): t1.get() @@ -63,18 +66,18 @@ async def reset_check(): with pytest.raises(ValueError): t1.reset(token3) - _core.run(reset_check) + run(reset_check) -def test_runvar_sync(): - t1 = _core.RunVar("test1") +def test_runvar_sync() -> None: + t1 = RunVar[str]("test1") - async def sync_check(): - async def task1(): + async def sync_check() -> None: + async def task1() -> None: t1.set("plaice") assert t1.get() == "plaice" - async def task2(tok): + async def task2(tok: str) -> None: t1.reset(token) with pytest.raises(LookupError): @@ -94,11 +97,11 @@ async def task2(tok): await _core.wait_all_tasks_blocked() assert t1.get() == "haddock" - _core.run(sync_check) + run(sync_check) -def test_accessing_runvar_outside_run_call_fails(): - t1 = _core.RunVar("test1") +def test_accessing_runvar_outside_run_call_fails() -> None: + t1 = RunVar[str]("test1") with pytest.raises(RuntimeError): t1.set("asdf") @@ -106,10 +109,10 @@ def test_accessing_runvar_outside_run_call_fails(): with pytest.raises(RuntimeError): t1.get() - async def get_token(): + async def get_token() -> RunVarToken[str]: return t1.set("ok") - token = _core.run(get_token) + token = run(get_token) with pytest.raises(RuntimeError): t1.reset(token) diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py index de78443f4e..3cd79ecd8a 100644 --- a/trio/_core/_tests/test_thread_cache.py +++ b/trio/_core/_tests/test_thread_cache.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import threading import time from contextlib import contextmanager from queue import Queue +from typing import NoReturn import pytest +from outcome import Outcome from .. import _thread_cache from .._thread_cache import ThreadCache, start_thread_soon @@ -11,9 +15,9 @@ def test_thread_cache_basics(): - q = Queue() + q = Queue[Outcome]() - def fn(): + def fn() -> NoReturn: raise RuntimeError("hi") def deliver(outcome): diff --git a/trio/_core/_tests/test_unbounded_queue.py b/trio/_core/_tests/test_unbounded_queue.py index cffeed1618..33eb41a5b3 100644 --- a/trio/_core/_tests/test_unbounded_queue.py +++ b/trio/_core/_tests/test_unbounded_queue.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import pytest @@ -10,8 +12,8 @@ ) -async def test_UnboundedQueue_basic(): - q = _core.UnboundedQueue() +async def test_UnboundedQueue_basic() -> None: + q: _core.UnboundedQueue[str | int | None] = _core.UnboundedQueue() q.put_nowait("hi") assert await q.get_batch() == ["hi"] with pytest.raises(_core.WouldBlock): @@ -35,17 +37,17 @@ async def test_UnboundedQueue_basic(): repr(q) -async def test_UnboundedQueue_blocking(): +async def test_UnboundedQueue_blocking() -> None: record = [] - q = _core.UnboundedQueue() + q = _core.UnboundedQueue[int]() - async def get_batch_consumer(): + async def get_batch_consumer() -> None: while True: batch = await q.get_batch() assert batch record.append(batch) - async def aiter_consumer(): + async def aiter_consumer() -> None: async for batch in q: assert batch record.append(batch) @@ -67,8 +69,8 @@ async def aiter_consumer(): nursery.cancel_scope.cancel() -async def test_UnboundedQueue_fairness(): - q = _core.UnboundedQueue() +async def test_UnboundedQueue_fairness() -> None: + q = _core.UnboundedQueue[int]() # If there's no-one else around, we can put stuff in and take it out # again, no problem @@ -78,7 +80,7 @@ async def test_UnboundedQueue_fairness(): result = None - async def get_batch(q): + async def get_batch(q: _core.UnboundedQueue[int]) -> None: nonlocal result result = await q.get_batch() @@ -95,7 +97,7 @@ async def get_batch(q): # If two tasks are trying to read, they alternate record = [] - async def reader(name): + async def reader(name: str) -> None: while True: record.append((name, await q.get_batch())) @@ -114,8 +116,8 @@ async def reader(name): assert record == list(zip(itertools.cycle("ab"), [[i] for i in range(20)])) -async def test_UnboundedQueue_trivial_yields(): - q = _core.UnboundedQueue() +async def test_UnboundedQueue_trivial_yields() -> None: + q = _core.UnboundedQueue[None]() q.put_nowait(None) with assert_checkpoints(): @@ -127,17 +129,17 @@ async def test_UnboundedQueue_trivial_yields(): break -async def test_UnboundedQueue_no_spurious_wakeups(): +async def test_UnboundedQueue_no_spurious_wakeups() -> None: # If we have two tasks waiting, and put two items into the queue... then # only one task wakes up record = [] - async def getter(q, i): + async def getter(q: _core.UnboundedQueue[int], i: int) -> None: got = await q.get_batch() record.append((i, got)) async with _core.open_nursery() as nursery: - q = _core.UnboundedQueue() + q = _core.UnboundedQueue[int]() nursery.start_soon(getter, q, 1) await wait_all_tasks_blocked() nursery.start_soon(getter, q, 2) diff --git a/trio/lowlevel.py b/trio/lowlevel.py index 25e64975e2..964dabb556 100644 --- a/trio/lowlevel.py +++ b/trio/lowlevel.py @@ -15,6 +15,7 @@ RaiseCancelT as RaiseCancelT, RunStatistics as RunStatistics, RunVar as RunVar, + RunVarToken as RunVarToken, Task as Task, TrioToken as TrioToken, UnboundedQueue as UnboundedQueue, From 8eb11900a7e4ba59e668790f7c161903ef352ea2 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 15 Oct 2023 15:22:32 +0200 Subject: [PATCH 02/35] pyannotate --safe --- trio/_tests/test_exports.py | 6 +- trio/_tests/test_file_io.py | 38 +++--- trio/_tests/test_highlevel_generic.py | 16 +-- trio/_tests/test_highlevel_open_tcp_stream.py | 2 +- .../_tests/test_highlevel_open_unix_stream.py | 8 +- trio/_tests/test_highlevel_serve_listeners.py | 20 +-- trio/_tests/test_highlevel_socket.py | 2 +- trio/_tests/test_highlevel_ssl_helpers.py | 8 +- trio/_tests/test_path.py | 48 +++---- trio/_tests/test_scheduler_determinism.py | 6 +- trio/_tests/test_socket.py | 2 +- trio/_tests/test_ssl.py | 126 +++++++++--------- trio/_tests/test_subprocess.py | 56 ++++---- trio/_tests/test_sync.py | 62 ++++----- trio/_tests/test_testing.py | 90 ++++++------- trio/_tests/test_threads.py | 116 ++++++++-------- trio/_tests/test_timeouts.py | 18 +-- trio/_tests/test_tracing.py | 4 +- trio/_tests/test_unix_pipes.py | 36 ++--- trio/_tests/test_util.py | 26 ++-- trio/_tests/test_wait_for_object.py | 10 +- trio/_tests/test_windows_pipes.py | 18 +-- trio/_tests/tools/test_gen_exports.py | 6 +- 23 files changed, 363 insertions(+), 361 deletions(-) diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py index c3d8a03b63..7a016e86d3 100644 --- a/trio/_tests/test_exports.py +++ b/trio/_tests/test_exports.py @@ -34,7 +34,7 @@ Protocol_ext = Protocol # type: ignore[assignment] -def _ensure_mypy_cache_updated(): +def _ensure_mypy_cache_updated() -> None: # This pollutes the `empty` dir. Should this be changed? try: from mypy.api import run @@ -59,7 +59,7 @@ def _ensure_mypy_cache_updated(): mypy_cache_updated = True -def test_core_is_properly_reexported(): +def test_core_is_properly_reexported() -> None: # Each export from _core should be re-exported by exactly one of these # three modules: sources = [trio, trio.lowlevel, trio.testing] @@ -126,7 +126,7 @@ def iter_modules( # https://github.com/pypa/setuptools/issues/3274 "ignore:module 'sre_constants' is deprecated:DeprecationWarning", ) -def test_static_tool_sees_all_symbols(tool, modname, tmpdir): +def test_static_tool_sees_all_symbols(tool, modname, tmpdir) -> None: module = importlib.import_module(modname) def no_underscores(symbols): diff --git a/trio/_tests/test_file_io.py b/trio/_tests/test_file_io.py index bae426cf48..863ebe81b0 100644 --- a/trio/_tests/test_file_io.py +++ b/trio/_tests/test_file_io.py @@ -28,17 +28,17 @@ def async_file(wrapped): return trio.wrap_file(wrapped) -def test_wrap_invalid(): +def test_wrap_invalid() -> None: with pytest.raises(TypeError): trio.wrap_file("") -def test_wrap_non_iobase(): +def test_wrap_non_iobase() -> None: class FakeFile: - def close(self): # pragma: no cover + def close(self) -> None: # pragma: no cover pass - def write(self): # pragma: no cover + def write(self) -> None: # pragma: no cover pass wrapped = FakeFile() @@ -53,11 +53,11 @@ def write(self): # pragma: no cover trio.wrap_file(FakeFile()) -def test_wrapped_property(async_file, wrapped): +def test_wrapped_property(async_file, wrapped) -> None: assert async_file.wrapped is wrapped -def test_dir_matches_wrapped(async_file, wrapped): +def test_dir_matches_wrapped(async_file, wrapped) -> None: attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS) # all supported attrs in wrapped should be available in async_file @@ -68,9 +68,9 @@ def test_dir_matches_wrapped(async_file, wrapped): ) -def test_unsupported_not_forwarded(): +def test_unsupported_not_forwarded() -> None: class FakeFile(io.RawIOBase): - def unsupported_attr(self): # pragma: no cover + def unsupported_attr(self) -> None: # pragma: no cover pass async_file = trio.wrap_file(FakeFile()) @@ -121,7 +121,7 @@ def test_type_stubs_match_lists() -> None: assert found == expected -def test_sync_attrs_forwarded(async_file, wrapped): +def test_sync_attrs_forwarded(async_file, wrapped) -> None: for attr_name in _FILE_SYNC_ATTRS: if attr_name not in dir(async_file): continue @@ -129,7 +129,7 @@ def test_sync_attrs_forwarded(async_file, wrapped): assert getattr(async_file, attr_name) is getattr(wrapped, attr_name) -def test_sync_attrs_match_wrapper(async_file, wrapped): +def test_sync_attrs_match_wrapper(async_file, wrapped) -> None: for attr_name in _FILE_SYNC_ATTRS: if attr_name in dir(async_file): continue @@ -141,7 +141,7 @@ def test_sync_attrs_match_wrapper(async_file, wrapped): getattr(wrapped, attr_name) -def test_async_methods_generated_once(async_file): +def test_async_methods_generated_once(async_file) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name not in dir(async_file): continue @@ -149,7 +149,7 @@ def test_async_methods_generated_once(async_file): assert getattr(async_file, meth_name) is getattr(async_file, meth_name) -def test_async_methods_signature(async_file): +def test_async_methods_signature(async_file) -> None: # use read as a representative of all async methods assert async_file.read.__name__ == "read" assert async_file.read.__qualname__ == "AsyncIOWrapper.read" @@ -157,7 +157,7 @@ def test_async_methods_signature(async_file): assert "io.StringIO.read" in async_file.read.__doc__ -async def test_async_methods_wrap(async_file, wrapped): +async def test_async_methods_wrap(async_file, wrapped) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name not in dir(async_file): continue @@ -175,7 +175,7 @@ async def test_async_methods_wrap(async_file, wrapped): wrapped.reset_mock() -async def test_async_methods_match_wrapper(async_file, wrapped): +async def test_async_methods_match_wrapper(async_file, wrapped) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name in dir(async_file): continue @@ -187,7 +187,7 @@ async def test_async_methods_match_wrapper(async_file, wrapped): getattr(wrapped, meth_name) -async def test_open(path): +async def test_open(path) -> None: f = await trio.open_file(path, "w") assert isinstance(f, AsyncIOWrapper) @@ -195,7 +195,7 @@ async def test_open(path): await f.aclose() -async def test_open_context_manager(path): +async def test_open_context_manager(path) -> None: async with await trio.open_file(path, "w") as f: assert isinstance(f, AsyncIOWrapper) assert not f.closed @@ -203,7 +203,7 @@ async def test_open_context_manager(path): assert f.closed -async def test_async_iter(): +async def test_async_iter() -> None: async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar")) expected = list(async_file.wrapped) result = [] @@ -215,7 +215,7 @@ async def test_async_iter(): assert result == expected -async def test_aclose_cancelled(path): +async def test_aclose_cancelled(path) -> None: with _core.CancelScope() as cscope: f = await trio.open_file(path, "w") cscope.cancel() @@ -229,7 +229,7 @@ async def test_aclose_cancelled(path): assert f.closed -async def test_detach_rewraps_asynciobase(): +async def test_detach_rewraps_asynciobase() -> None: raw = io.BytesIO() buffered = io.BufferedReader(raw) diff --git a/trio/_tests/test_highlevel_generic.py b/trio/_tests/test_highlevel_generic.py index 38bcedee25..64c5697184 100644 --- a/trio/_tests/test_highlevel_generic.py +++ b/trio/_tests/test_highlevel_generic.py @@ -9,13 +9,13 @@ class RecordSendStream(SendStream): record = attr.ib(factory=list) - async def send_all(self, data): + async def send_all(self, data) -> None: self.record.append(("send_all", data)) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: self.record.append("wait_send_all_might_not_block") - async def aclose(self): + async def aclose(self) -> None: self.record.append("aclose") @@ -23,14 +23,14 @@ async def aclose(self): class RecordReceiveStream(ReceiveStream): record = attr.ib(factory=list) - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes=None) -> None: self.record.append(("receive_some", max_bytes)) - async def aclose(self): + async def aclose(self) -> None: self.record.append("aclose") -async def test_StapledStream(): +async def test_StapledStream() -> None: send_stream = RecordSendStream() receive_stream = RecordReceiveStream() stapled = StapledStream(send_stream, receive_stream) @@ -50,7 +50,7 @@ async def test_StapledStream(): assert send_stream.record == ["aclose"] send_stream.record.clear() - async def fake_send_eof(): + async def fake_send_eof() -> None: send_stream.record.append("send_eof") send_stream.send_eof = fake_send_eof @@ -70,7 +70,7 @@ async def fake_send_eof(): assert send_stream.record == ["aclose"] -async def test_StapledStream_with_erroring_close(): +async def test_StapledStream_with_erroring_close() -> None: # Make sure that if one of the aclose methods errors out, then the other # one still gets called. class BrokenSendStream(RecordSendStream): diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py index f875bfa019..7917e66c50 100644 --- a/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/trio/_tests/test_highlevel_open_tcp_stream.py @@ -244,7 +244,7 @@ def __init__( port: int, ip_list: Sequence[tuple[str, float, str]], supported_families: set[AddressFamily], - ): + ) -> None: # ip_list have to be unique ip_order = [ip for (ip, _, _) in ip_list] assert len(set(ip_order)) == len(ip_list) diff --git a/trio/_tests/test_highlevel_open_unix_stream.py b/trio/_tests/test_highlevel_open_unix_stream.py index 64a15f9e9d..cf32b8c0fc 100644 --- a/trio/_tests/test_highlevel_open_unix_stream.py +++ b/trio/_tests/test_highlevel_open_unix_stream.py @@ -15,7 +15,7 @@ def test_close_on_error(): class CloseMe: closed = False - def close(self): + def close(self) -> None: self.closed = True with close_on_error(CloseMe()) as c: @@ -29,12 +29,12 @@ def close(self): @pytest.mark.parametrize("filename", [4, 4.5]) -async def test_open_with_bad_filename_type(filename): +async def test_open_with_bad_filename_type(filename) -> None: with pytest.raises(TypeError): await open_unix_socket(filename) -async def test_open_bad_socket(): +async def test_open_bad_socket() -> None: # mktemp is marked as insecure, but that's okay, we don't want the file to # exist name = tempfile.mktemp() @@ -42,7 +42,7 @@ async def test_open_bad_socket(): await open_unix_socket(name) -async def test_open_unix_socket(): +async def test_open_unix_socket() -> None: for name_type in [Path, str]: name = tempfile.mktemp() serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py index 65804f4222..beac59f86d 100644 --- a/trio/_tests/test_highlevel_serve_listeners.py +++ b/trio/_tests/test_highlevel_serve_listeners.py @@ -32,33 +32,33 @@ async def accept(self): self.accepted_streams.append(stream) return stream - async def aclose(self): + async def aclose(self) -> None: self.closed = True await trio.lowlevel.checkpoint() -async def test_serve_listeners_basic(): +async def test_serve_listeners_basic() -> None: listeners = [MemoryListener(), MemoryListener()] record = [] - def close_hook(): + def close_hook() -> None: # Make sure this is a forceful close assert trio.current_effective_deadline() == float("-inf") record.append("closed") - async def handler(stream): + async def handler(stream) -> None: await stream.send_all(b"123") assert await stream.receive_some(10) == b"456" stream.send_stream.close_hook = close_hook stream.receive_stream.close_hook = close_hook - async def client(listener): + async def client(listener) -> None: s = await listener.connect() assert await s.receive_some(10) == b"123" await s.send_all(b"456") - async def do_tests(parent_nursery): + async def do_tests(parent_nursery) -> None: async with trio.open_nursery() as nursery: for listener in listeners: for _ in range(3): @@ -82,7 +82,7 @@ async def do_tests(parent_nursery): assert listener.closed -async def test_serve_listeners_accept_unrecognized_error(): +async def test_serve_listeners_accept_unrecognized_error() -> None: for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]: listener = MemoryListener() @@ -96,7 +96,7 @@ async def raise_error(): assert excinfo.value is error -async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog): +async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog) -> None: listener = MemoryListener() async def raise_EMFILE(): @@ -115,10 +115,10 @@ async def raise_EMFILE(): assert record.exc_info[1].errno == errno.EMFILE -async def test_serve_listeners_connection_nursery(autojump_clock): +async def test_serve_listeners_connection_nursery(autojump_clock) -> None: listener = MemoryListener() - async def handler(stream): + async def handler(stream) -> None: await trio.sleep(1) class Done(Exception): diff --git a/trio/_tests/test_highlevel_socket.py b/trio/_tests/test_highlevel_socket.py index 514b6ad196..61e891e94b 100644 --- a/trio/_tests/test_highlevel_socket.py +++ b/trio/_tests/test_highlevel_socket.py @@ -211,7 +211,7 @@ async def test_SocketListener_socket_closed_underfoot() -> None: async def test_SocketListener_accept_errors() -> None: class FakeSocket(tsocket.SocketType): - def __init__(self, events: Sequence[SocketType | BaseException]): + def __init__(self, events: Sequence[SocketType | BaseException]) -> None: self._events = iter(events) type = tsocket.SOCK_STREAM diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py index 89d921476a..9bec4ae2f5 100644 --- a/trio/_tests/test_highlevel_ssl_helpers.py +++ b/trio/_tests/test_highlevel_ssl_helpers.py @@ -17,7 +17,7 @@ from .test_ssl import SERVER_CTX, client_ctx # noqa: F401 -async def echo_handler(stream): +async def echo_handler(stream) -> None: async with stream: try: while True: @@ -44,7 +44,9 @@ async def getnameinfo(self, *args): # pragma: no cover # This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners... # using noqa because linters don't understand how pytest fixtures work. -async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa: F811 +async def test_open_ssl_over_tcp_stream_and_everything_else( + client_ctx, # noqa: F811 # linters doesn't understand fixture +) -> None: async with trio.open_nursery() as nursery: (listener,) = await nursery.start( partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1") @@ -97,7 +99,7 @@ async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa nursery.cancel_scope.cancel() -async def test_open_ssl_over_tcp_listeners(): +async def test_open_ssl_over_tcp_listeners() -> None: (listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1") async with listener: assert isinstance(listener, trio.SSLListener) diff --git a/trio/_tests/test_path.py b/trio/_tests/test_path.py index bfef1aaf2c..1d17029bdf 100644 --- a/trio/_tests/test_path.py +++ b/trio/_tests/test_path.py @@ -20,14 +20,14 @@ def method_pair(path, method_name): return getattr(path, method_name), getattr(async_path, method_name) -async def test_open_is_async_context_manager(path): +async def test_open_is_async_context_manager(path) -> None: async with await path.open("w") as f: assert isinstance(f, AsyncIOWrapper) assert f.closed -async def test_magic(): +async def test_magic() -> None: path = trio.Path("test") assert str(path) == "test" @@ -42,7 +42,7 @@ async def test_magic(): @pytest.mark.parametrize("cls_a,cls_b", cls_pairs) -async def test_cmp_magic(cls_a, cls_b): +async def test_cmp_magic(cls_a, cls_b) -> None: a, b = cls_a(""), cls_b("") assert a == b assert not a != b @@ -69,7 +69,7 @@ async def test_cmp_magic(cls_a, cls_b): @pytest.mark.parametrize("cls_a,cls_b", cls_pairs) -async def test_div_magic(cls_a, cls_b): +async def test_div_magic(cls_a, cls_b) -> None: a, b = cls_a("a"), cls_b("b") result = a / b @@ -81,19 +81,19 @@ async def test_div_magic(cls_a, cls_b): "cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] ) @pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"]) -async def test_hash_magic(cls_a, cls_b, path): +async def test_hash_magic(cls_a, cls_b, path) -> None: a, b = cls_a(path), cls_b(path) assert hash(a) == hash(b) -async def test_forwarded_properties(path): +async def test_forwarded_properties(path) -> None: # use `name` as a representative of forwarded properties assert "name" in dir(path) assert path.name == "test" -async def test_async_method_signature(path): +async def test_async_method_signature(path) -> None: # use `resolve` as a representative of wrapped methods assert path.resolve.__name__ == "resolve" @@ -103,7 +103,7 @@ async def test_async_method_signature(path): @pytest.mark.parametrize("method_name", ["is_dir", "is_file"]) -async def test_compare_async_stat_methods(method_name): +async def test_compare_async_stat_methods(method_name) -> None: method, async_method = method_pair(".", method_name) result = method() @@ -112,13 +112,13 @@ async def test_compare_async_stat_methods(method_name): assert result == async_result -async def test_invalid_name_not_wrapped(path): +async def test_invalid_name_not_wrapped(path) -> None: with pytest.raises(AttributeError): getattr(path, "invalid_fake_attr") @pytest.mark.parametrize("method_name", ["absolute", "resolve"]) -async def test_async_methods_rewrap(method_name): +async def test_async_methods_rewrap(method_name) -> None: method, async_method = method_pair(".", method_name) result = method() @@ -128,7 +128,7 @@ async def test_async_methods_rewrap(method_name): assert str(result) == str(async_result) -async def test_forward_methods_rewrap(path, tmpdir): +async def test_forward_methods_rewrap(path, tmpdir) -> None: with_name = path.with_name("foo") with_suffix = path.with_suffix(".py") @@ -138,17 +138,17 @@ async def test_forward_methods_rewrap(path, tmpdir): assert with_suffix == tmpdir.join("test.py") -async def test_forward_properties_rewrap(path): +async def test_forward_properties_rewrap(path) -> None: assert isinstance(path.parent, trio.Path) -async def test_forward_methods_without_rewrap(path, tmpdir): +async def test_forward_methods_without_rewrap(path, tmpdir) -> None: path = await path.parent.resolve() assert path.as_uri().startswith("file:///") -async def test_repr(): +async def test_repr() -> None: path = trio.Path(".") assert repr(path) == "trio.Path('.')" @@ -164,30 +164,30 @@ class MockWrapper: _wraps = MockWrapped -async def test_type_forwards_unsupported(): +async def test_type_forwards_unsupported() -> None: with pytest.raises(TypeError): Type.generate_forwards(MockWrapper, {}) -async def test_type_wraps_unsupported(): +async def test_type_wraps_unsupported() -> None: with pytest.raises(TypeError): Type.generate_wraps(MockWrapper, {}) -async def test_type_forwards_private(): +async def test_type_forwards_private() -> None: Type.generate_forwards(MockWrapper, {"unsupported": None}) assert not hasattr(MockWrapper, "_private") -async def test_type_wraps_private(): +async def test_type_wraps_private() -> None: Type.generate_wraps(MockWrapper, {"unsupported": None}) assert not hasattr(MockWrapper, "_private") @pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath]) -async def test_path_wraps_path(path, meth): +async def test_path_wraps_path(path, meth) -> None: wrapped = await path.absolute() result = meth(path, wrapped) if result is None: @@ -196,17 +196,17 @@ async def test_path_wraps_path(path, meth): assert wrapped == result -async def test_path_nonpath(): +async def test_path_nonpath() -> None: with pytest.raises(TypeError): trio.Path(1) -async def test_open_file_can_open_path(path): +async def test_open_file_can_open_path(path) -> None: async with await trio.open_file(path, "w") as f: assert f.name == os.fspath(path) -async def test_globmethods(path): +async def test_globmethods(path) -> None: # Populate a directory tree await path.mkdir() await (path / "foo").mkdir() @@ -235,7 +235,7 @@ async def test_globmethods(path): assert entries == {"_bar.txt", "bar.txt"} -async def test_iterdir(path): +async def test_iterdir(path) -> None: # Populate a directory await path.mkdir() await (path / "foo").mkdir() @@ -249,7 +249,7 @@ async def test_iterdir(path): assert entries == {"bar.txt", "foo"} -async def test_classmethods(): +async def test_classmethods() -> None: assert isinstance(await trio.Path.home(), trio.Path) # pathlib.Path has only two classmethods diff --git a/trio/_tests/test_scheduler_determinism.py b/trio/_tests/test_scheduler_determinism.py index e2d3167e45..1c438f136c 100644 --- a/trio/_tests/test_scheduler_determinism.py +++ b/trio/_tests/test_scheduler_determinism.py @@ -5,7 +5,7 @@ async def scheduler_trace(): """Returns a scheduler-dependent value we can use to check determinism.""" trace = [] - async def tracer(name): + async def tracer(name) -> None: for i in range(50): trace.append((name, i)) await trio.sleep(0) @@ -17,7 +17,7 @@ async def tracer(name): return tuple(trace) -def test_the_trio_scheduler_is_not_deterministic(): +def test_the_trio_scheduler_is_not_deterministic() -> None: # At least, not yet. See https://github.com/python-trio/trio/issues/32 traces = [] for _ in range(10): @@ -25,7 +25,7 @@ def test_the_trio_scheduler_is_not_deterministic(): assert len(set(traces)) == len(traces) -def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch): +def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch) -> None: monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True) traces = [] for _ in range(10): diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py index 40ffefb8cd..6b1753c82b 100644 --- a/trio/_tests/test_socket.py +++ b/trio/_tests/test_socket.py @@ -39,7 +39,7 @@ class MonkeypatchedGAI: - def __init__(self, orig_getaddrinfo: Callable[..., getaddrinfoResponse]): + def __init__(self, orig_getaddrinfo: Callable[..., getaddrinfoResponse]) -> None: self._orig_getaddrinfo = orig_getaddrinfo self._responses: dict[tuple[Any, ...], getaddrinfoResponse | str] = {} self.record: list[tuple[Any, ...]] = [] diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index d92e35b0d9..06ac2e694c 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -167,7 +167,7 @@ async def ssl_echo_server(client_ctx, **kwargs): # Doesn't inherit from Stream because I left out the methods that we don't # actually need. class PyOpenSSLEchoStream: - def __init__(self, sleeper=None): + def __init__(self, sleeper=None) -> None: ctx = SSL.Context(SSL.SSLv23_METHOD) # TLS 1.3 removes renegotiation support. Which is great for them, but # we still have to support versions before that, and that means we @@ -217,31 +217,31 @@ def __init__(self, sleeper=None): if sleeper is None: - async def no_op_sleeper(_): + async def no_op_sleeper(_) -> None: return self.sleeper = no_op_sleeper else: self.sleeper = sleeper - async def aclose(self): + async def aclose(self) -> None: self._conn.bio_shutdown() def renegotiate_pending(self): return self._conn.renegotiate_pending() - def renegotiate(self): + def renegotiate(self) -> None: # Returns false if a renegotiation is already in progress, meaning # nothing happens. assert self._conn.renegotiate() - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_all_conflict_detector: await _core.checkpoint() await _core.checkpoint() await self.sleeper("wait_send_all_might_not_block") - async def send_all(self, data): + async def send_all(self, data) -> None: print(" --> transport_stream.send_all") with self._send_all_conflict_detector: await _core.checkpoint() @@ -320,7 +320,7 @@ async def receive_some(self, nbytes=None): print(" <-- transport_stream.receive_some finished") -async def test_PyOpenSSLEchoStream_gives_resource_busy_errors(): +async def test_PyOpenSSLEchoStream_gives_resource_busy_errors() -> None: # Make sure that PyOpenSSLEchoStream complains if two tasks call send_all # at the same time, or ditto for receive_some. The tricky cases where SSLStream # might accidentally do this are during renegotiation, which we test using @@ -395,7 +395,7 @@ def ssl_lockstep_stream_pair(client_ctx, **kwargs): # Simple smoke test for handshake/send/receive/shutdown talking to a # synchronous server, plus make sure that we do the bare minimum of # certificate checking (even though this is really Python's responsibility) -async def test_ssl_client_basics(client_ctx): +async def test_ssl_client_basics(client_ctx) -> None: # Everything OK async with ssl_echo_server(client_ctx) as s: assert not s.server_side @@ -421,7 +421,7 @@ async def test_ssl_client_basics(client_ctx): assert isinstance(excinfo.value.__cause__, ssl.CertificateError) -async def test_ssl_server_basics(client_ctx): +async def test_ssl_server_basics(client_ctx) -> None: a, b = stdlib_socket.socketpair() with a, b: server_sock = tsocket.from_stdlib_socket(b) @@ -430,7 +430,7 @@ async def test_ssl_server_basics(client_ctx): ) assert server_transport.server_side - def client(): + def client() -> None: with client_ctx.wrap_socket( a, server_hostname="trio-test-1.example.org" ) as client_sock: @@ -451,7 +451,7 @@ def client(): t.join() -async def test_attributes(client_ctx): +async def test_attributes(client_ctx) -> None: async with ssl_echo_server_raw(expect_fail=True) as sock: good_ctx = client_ctx bad_ctx = ssl.create_default_context() @@ -520,7 +520,7 @@ async def test_attributes(client_ctx): # I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it... -async def test_full_duplex_basics(client_ctx): +async def test_full_duplex_basics(client_ctx) -> None: CHUNKS = 30 CHUNK_SIZE = 32768 EXPECTED = CHUNKS * CHUNK_SIZE @@ -528,7 +528,7 @@ async def test_full_duplex_basics(client_ctx): sent = bytearray() received = bytearray() - async def sender(s): + async def sender(s) -> None: nonlocal sent for i in range(CHUNKS): print(i) @@ -536,7 +536,7 @@ async def sender(s): sent += chunk await s.send_all(chunk) - async def receiver(s): + async def receiver(s) -> None: nonlocal received while len(received) < EXPECTED: chunk = await s.receive_some(CHUNK_SIZE // 2) @@ -557,7 +557,7 @@ async def receiver(s): assert sent == received -async def test_renegotiation_simple(client_ctx): +async def test_renegotiation_simple(client_ctx) -> None: with virtual_ssl_echo_server(client_ctx) as s: await s.do_handshake() @@ -576,7 +576,7 @@ async def test_renegotiation_simple(client_ctx): @slow -async def test_renegotiation_randomized(mock_clock, client_ctx): +async def test_renegotiation_randomized(mock_clock, client_ctx) -> None: # The only blocking things in this function are our random sleeps, so 0 is # a good threshold. mock_clock.autojump_threshold = 0 @@ -585,10 +585,10 @@ async def test_renegotiation_randomized(mock_clock, client_ctx): r = random.Random(0) - async def sleeper(_): + async def sleeper(_) -> None: await trio.sleep(r.uniform(0, 10)) - async def clear(): + async def clear() -> None: while s.transport_stream.renegotiate_pending(): with assert_checkpoints(): await send(b"-") @@ -596,13 +596,13 @@ async def clear(): await expect(b"-") print("-- clear --") - async def send(byte): + async def send(byte) -> None: await s.transport_stream.sleeper("outer send") print("calling SSLStream.send_all", byte) with assert_checkpoints(): await s.send_all(byte) - async def expect(expected): + async def expect(expected) -> None: await s.transport_stream.sleeper("expect") print("calling SSLStream.receive_some, expecting", expected) assert len(expected) == 1 @@ -648,13 +648,13 @@ async def expect(expected): # and wait_send_all_might_not_block comes in. # Our receive_some() call will get stuck when it hits send_all - async def sleeper_with_slow_send_all(method): + async def sleeper_with_slow_send_all(method) -> None: if method == "send_all": await trio.sleep(100000) # And our wait_send_all_might_not_block call will give it time to get # stuck, and then start - async def sleep_then_wait_writable(): + async def sleep_then_wait_writable() -> None: await trio.sleep(1000) await s.wait_send_all_might_not_block() @@ -672,7 +672,7 @@ async def sleep_then_wait_writable(): # 2) Same, but now wait_send_all_might_not_block is stuck when # receive_some tries to send. - async def sleeper_with_slow_wait_writable_and_expect(method): + async def sleeper_with_slow_wait_writable_and_expect(method) -> None: if method == "wait_send_all_might_not_block": await trio.sleep(100000) elif method == "expect": @@ -692,16 +692,16 @@ async def sleeper_with_slow_wait_writable_and_expect(method): await s.aclose() -async def test_resource_busy_errors(client_ctx): - async def do_send_all(): +async def test_resource_busy_errors(client_ctx) -> None: + async def do_send_all() -> None: with assert_checkpoints(): await s.send_all(b"x") - async def do_receive_some(): + async def do_receive_some() -> None: with assert_checkpoints(): await s.receive_some(1) - async def do_wait_send_all_might_not_block(): + async def do_wait_send_all_might_not_block() -> None: with assert_checkpoints(): await s.wait_send_all_might_not_block() @@ -734,11 +734,11 @@ async def do_wait_send_all_might_not_block(): assert "another task" in str(excinfo.value) -async def test_wait_writable_calls_underlying_wait_writable(): +async def test_wait_writable_calls_underlying_wait_writable() -> None: record = [] class NotAStream: - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: record.append("ok") ctx = ssl.create_default_context() @@ -751,7 +751,7 @@ async def wait_send_all_might_not_block(self): os.name == "nt" and sys.version_info >= (3, 10), reason="frequently fails on Windows + Python 3.10", ) -async def test_checkpoints(client_ctx): +async def test_checkpoints(client_ctx) -> None: async with ssl_echo_server(client_ctx) as s: with assert_checkpoints(): await s.do_handshake() @@ -780,7 +780,7 @@ async def test_checkpoints(client_ctx): await s.aclose() -async def test_send_all_empty_string(client_ctx): +async def test_send_all_empty_string(client_ctx) -> None: async with ssl_echo_server(client_ctx) as s: await s.do_handshake() @@ -797,7 +797,7 @@ async def test_send_all_empty_string(client_ctx): @pytest.mark.parametrize("https_compatible", [False, True]) -async def test_SSLStream_generic(client_ctx, https_compatible): +async def test_SSLStream_generic(client_ctx, https_compatible) -> None: async def stream_maker(): return ssl_memory_stream_pair( client_ctx, @@ -821,14 +821,14 @@ async def clogged_stream_maker(): await check_two_way_stream(stream_maker, clogged_stream_maker) -async def test_unwrap(client_ctx): +async def test_unwrap(client_ctx) -> None: client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) client_transport = client_ssl.transport_stream server_transport = server_ssl.transport_stream seq = Sequencer() - async def client(): + async def client() -> None: await client_ssl.do_handshake() await client_ssl.send_all(b"x") assert await client_ssl.receive_some(1) == b"y" @@ -855,7 +855,7 @@ async def client(): client_transport.send_stream.send_all_hook = send_all_hook await client_transport.send_stream.send_all_hook() - async def server(): + async def server() -> None: await server_ssl.do_handshake() assert await server_ssl.receive_some(1) == b"x" await server_ssl.send_all(b"y") @@ -875,7 +875,7 @@ async def server(): nursery.start_soon(server) -async def test_closing_nice_case(client_ctx): +async def test_closing_nice_case(client_ctx) -> None: # the nice case: graceful closes all around client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) @@ -883,11 +883,11 @@ async def test_closing_nice_case(client_ctx): # Both the handshake and the close require back-and-forth discussion, so # we need to run them concurrently - async def client_closer(): + async def client_closer() -> None: with assert_checkpoints(): await client_ssl.aclose() - async def server_closer(): + async def server_closer() -> None: assert await server_ssl.receive_some(10) == b"" assert await server_ssl.receive_some(10) == b"" with assert_checkpoints(): @@ -926,7 +926,7 @@ async def server_closer(): # the other side client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) - async def expect_eof_server(): + async def expect_eof_server() -> None: with assert_checkpoints(): assert await server_ssl.receive_some(10) == b"" with assert_checkpoints(): @@ -937,7 +937,7 @@ async def expect_eof_server(): nursery.start_soon(expect_eof_server) -async def test_send_all_fails_in_the_middle(client_ctx): +async def test_send_all_fails_in_the_middle(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with _core.open_nursery() as nursery: @@ -957,7 +957,7 @@ async def bad_hook(): closed = 0 - def close_hook(): + def close_hook() -> None: nonlocal closed closed += 1 @@ -968,7 +968,7 @@ def close_hook(): assert closed == 2 -async def test_ssl_over_ssl(client_ctx): +async def test_ssl_over_ssl(client_ctx) -> None: client_0, server_0 = memory_stream_pair() client_1 = SSLStream( @@ -981,11 +981,11 @@ async def test_ssl_over_ssl(client_ctx): ) server_2 = SSLStream(server_1, SERVER_CTX, server_side=True) - async def client(): + async def client() -> None: await client_2.send_all(b"hi") assert await client_2.receive_some(10) == b"bye" - async def server(): + async def server() -> None: assert await server_2.receive_some(10) == b"hi" await server_2.send_all(b"bye") @@ -994,7 +994,7 @@ async def server(): nursery.start_soon(server) -async def test_ssl_bad_shutdown(client_ctx): +async def test_ssl_bad_shutdown(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with _core.open_nursery() as nursery: @@ -1011,7 +1011,7 @@ async def test_ssl_bad_shutdown(client_ctx): await server.aclose() -async def test_ssl_bad_shutdown_but_its_ok(client_ctx): +async def test_ssl_bad_shutdown_but_its_ok(client_ctx) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, @@ -1031,7 +1031,7 @@ async def test_ssl_bad_shutdown_but_its_ok(client_ctx): await server.aclose() -async def test_ssl_handshake_failure_during_aclose(): +async def test_ssl_handshake_failure_during_aclose() -> None: # Weird scenario: aclose() triggers an automatic handshake, and this # fails. This also exercises a bit of code in aclose() that was otherwise # uncovered, for re-raising exceptions after calling aclose_forcefully on @@ -1050,7 +1050,7 @@ async def test_ssl_handshake_failure_during_aclose(): await s.aclose() -async def test_ssl_only_closes_stream_once(client_ctx): +async def test_ssl_only_closes_stream_once(client_ctx) -> None: # We used to have a bug where if transport_stream.aclose() raised an # error, we would call it again. This checks that that's fixed. client, server = ssl_memory_stream_pair(client_ctx) @@ -1075,7 +1075,7 @@ def close_hook(): assert transport_close_count == 1 -async def test_ssl_https_compatibility_disagreement(client_ctx): +async def test_ssl_https_compatibility_disagreement(client_ctx) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": False}, @@ -1088,7 +1088,7 @@ async def test_ssl_https_compatibility_disagreement(client_ctx): # client is in HTTPS-mode, server is not # so client doing graceful_shutdown causes an error on server - async def receive_and_expect_error(): + async def receive_and_expect_error() -> None: with pytest.raises(BrokenResourceError) as excinfo: await server.receive_some(10) @@ -1099,14 +1099,14 @@ async def receive_and_expect_error(): nursery.start_soon(receive_and_expect_error) -async def test_https_mode_eof_before_handshake(client_ctx): +async def test_https_mode_eof_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, client_kwargs={"https_compatible": True}, ) - async def server_expect_clean_eof(): + async def server_expect_clean_eof() -> None: assert await server.receive_some(10) == b"" async with _core.open_nursery() as nursery: @@ -1114,7 +1114,7 @@ async def server_expect_clean_eof(): nursery.start_soon(server_expect_clean_eof) -async def test_send_error_during_handshake(client_ctx): +async def test_send_error_during_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async def bad_hook(): @@ -1131,7 +1131,7 @@ async def bad_hook(): await client.do_handshake() -async def test_receive_error_during_handshake(client_ctx): +async def test_receive_error_during_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async def bad_hook(): @@ -1139,7 +1139,7 @@ async def bad_hook(): client.transport_stream.receive_stream.receive_some_hook = bad_hook - async def client_side(cancel_scope): + async def client_side(cancel_scope) -> None: with pytest.raises(KeyError): with assert_checkpoints(): await client.do_handshake() @@ -1154,7 +1154,7 @@ async def client_side(cancel_scope): await client.do_handshake() -async def test_selected_alpn_protocol_before_handshake(client_ctx): +async def test_selected_alpn_protocol_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1164,7 +1164,7 @@ async def test_selected_alpn_protocol_before_handshake(client_ctx): server.selected_alpn_protocol() -async def test_selected_alpn_protocol_when_not_set(client_ctx): +async def test_selected_alpn_protocol_when_not_set(client_ctx) -> None: # ALPN protocol still returns None when it's not set, # instead of raising an exception client, server = ssl_memory_stream_pair(client_ctx) @@ -1179,7 +1179,7 @@ async def test_selected_alpn_protocol_when_not_set(client_ctx): assert client.selected_alpn_protocol() == server.selected_alpn_protocol() -async def test_selected_npn_protocol_before_handshake(client_ctx): +async def test_selected_npn_protocol_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1193,7 +1193,7 @@ async def test_selected_npn_protocol_before_handshake(client_ctx): r"ignore: ssl module. NPN is deprecated, use ALPN instead:UserWarning", r"ignore:ssl NPN is deprecated, use ALPN instead:DeprecationWarning", ) -async def test_selected_npn_protocol_when_not_set(client_ctx): +async def test_selected_npn_protocol_when_not_set(client_ctx) -> None: # NPN protocol still returns None when it's not set, # instead of raising an exception client, server = ssl_memory_stream_pair(client_ctx) @@ -1208,7 +1208,7 @@ async def test_selected_npn_protocol_when_not_set(client_ctx): assert client.selected_npn_protocol() == server.selected_npn_protocol() -async def test_get_channel_binding_before_handshake(client_ctx): +async def test_get_channel_binding_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1218,7 +1218,7 @@ async def test_get_channel_binding_before_handshake(client_ctx): server.get_channel_binding() -async def test_get_channel_binding_after_handshake(client_ctx): +async def test_get_channel_binding_after_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with _core.open_nursery() as nursery: @@ -1231,7 +1231,7 @@ async def test_get_channel_binding_after_handshake(client_ctx): assert client.get_channel_binding() == server.get_channel_binding() -async def test_getpeercert(client_ctx): +async def test_getpeercert(client_ctx) -> None: # Make sure we're not affected by https://bugs.python.org/issue29334 client, server = ssl_memory_stream_pair(client_ctx) @@ -1244,7 +1244,7 @@ async def test_getpeercert(client_ctx): assert ("DNS", "trio-test-1.example.org") in client.getpeercert()["subjectAltName"] -async def test_SSLListener(client_ctx): +async def test_SSLListener(client_ctx) -> None: async def setup(**kwargs): listen_sock = tsocket.socket() await listen_sock.bind(("127.0.0.1", 0)) diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index 17cf740012..2d805a1814 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -97,7 +97,7 @@ async def run_process_in_nursery(*args, **kwargs): @background_process_param -async def test_basic(background_process): +async def test_basic(background_process) -> None: async with background_process(EXIT_TRUE) as proc: await proc.wait() assert isinstance(proc, Process) @@ -114,7 +114,7 @@ async def test_basic(background_process): @background_process_param -async def test_auto_update_returncode(background_process): +async def test_auto_update_returncode(background_process) -> None: async with background_process(SLEEP(9999)) as p: assert p.returncode is None assert "running" in repr(p) @@ -127,7 +127,7 @@ async def test_auto_update_returncode(background_process): @background_process_param -async def test_multi_wait(background_process): +async def test_multi_wait(background_process) -> None: async with background_process(SLEEP(10)) as proc: # Check that wait (including multi-wait) tolerates being cancelled async with _core.open_nursery() as nursery: @@ -147,7 +147,7 @@ async def test_multi_wait(background_process): # Test for deprecated 'async with process:' semantics -async def test_async_with_basics_deprecated(recwarn): +async def test_async_with_basics_deprecated(recwarn) -> None: async with await open_process( CAT, stdin=subprocess.PIPE, stdout=subprocess.PIPE ) as proc: @@ -160,7 +160,7 @@ async def test_async_with_basics_deprecated(recwarn): # Test for deprecated 'async with process:' semantics -async def test_kill_when_context_cancelled(recwarn): +async def test_kill_when_context_cancelled(recwarn) -> None: with move_on_after(100) as scope: async with await open_process(SLEEP(10)) as proc: assert proc.poll() is None @@ -181,7 +181,7 @@ async def test_kill_when_context_cancelled(recwarn): @background_process_param -async def test_pipes(background_process): +async def test_pipes(background_process) -> None: async with background_process( COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, stdin=subprocess.PIPE, @@ -190,11 +190,11 @@ async def test_pipes(background_process): ) as proc: msg = b"the quick brown fox jumps over the lazy dog" - async def feed_input(): + async def feed_input() -> None: await proc.stdin.send_all(msg) await proc.stdin.aclose() - async def check_output(stream, expected): + async def check_output(stream, expected) -> None: seen = bytearray() async for chunk in stream: seen += chunk @@ -212,7 +212,7 @@ async def check_output(stream, expected): @background_process_param -async def test_interactive(background_process): +async def test_interactive(background_process) -> None: # Test some back-and-forth with a subprocess. This one works like so: # in: 32\n # out: 0000...0000\n (32 zeroes) @@ -241,10 +241,10 @@ async def test_interactive(background_process): ) as proc: newline = b"\n" if posix else b"\r\n" - async def expect(idx, request): + async def expect(idx, request) -> None: async with _core.open_nursery() as nursery: - async def drain_one(stream, count, digit): + async def drain_one(stream, count, digit) -> None: while count > 0: result = await stream.receive_some(count) assert result == (f"{digit}".encode() * len(result)) @@ -279,7 +279,7 @@ async def drain_one(stream, count, digit): assert proc.returncode == 0 -async def test_run(): +async def test_run() -> None: data = bytes(random.randint(0, 255) for _ in range(2**18)) result = await run_process( @@ -322,7 +322,7 @@ async def test_run(): await run_process(CAT, capture_stderr=True, stderr=None) -async def test_run_check(): +async def test_run_check() -> None: cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)") with pytest.raises(subprocess.CalledProcessError) as excinfo: await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True) @@ -341,7 +341,7 @@ async def test_run_check(): @skip_if_fbsd_pipes_broken -async def test_run_with_broken_pipe(): +async def test_run_with_broken_pipe() -> None: result = await run_process( [sys.executable, "-c", "import sys; sys.stdin.close()"], stdin=b"x" * 131072 ) @@ -350,7 +350,7 @@ async def test_run_with_broken_pipe(): @background_process_param -async def test_stderr_stdout(background_process): +async def test_stderr_stdout(background_process) -> None: async with background_process( COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, stdin=subprocess.PIPE, @@ -416,7 +416,7 @@ async def test_stderr_stdout(background_process): os.close(r) -async def test_errors(): +async def test_errors() -> None: with pytest.raises(TypeError) as excinfo: await open_process(["ls"], encoding="utf-8") assert "unbuffered byte streams" in str(excinfo.value) @@ -430,8 +430,8 @@ async def test_errors(): @background_process_param -async def test_signals(background_process): - async def test_one_signal(send_it, signum): +async def test_signals(background_process) -> None: + async def test_one_signal(send_it, signum) -> None: with move_on_after(1.0) as scope: async with background_process(SLEEP(3600)) as proc: send_it(proc) @@ -457,7 +457,7 @@ async def test_one_signal(send_it, signum): @pytest.mark.skipif(not posix, reason="POSIX specific") @background_process_param -async def test_wait_reapable_fails(background_process): +async def test_wait_reapable_fails(background_process) -> None: old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN) try: # With SIGCHLD disabled, the wait() syscall will wait for the @@ -476,7 +476,7 @@ async def test_wait_reapable_fails(background_process): @slow -def test_waitid_eintr(): +def test_waitid_eintr() -> None: # This only matters on PyPy (where we're coding EINTR handling # ourselves) but the test works on all waitid platforms. from .._subprocess_platform import wait_child_exiting @@ -488,7 +488,7 @@ def test_waitid_eintr(): got_alarm = False sleeper = subprocess.Popen(["sleep", "3600"]) - def on_alarm(sig, frame): + def on_alarm(sig, frame) -> None: nonlocal got_alarm got_alarm = True sleeper.kill() @@ -507,10 +507,10 @@ def on_alarm(sig, frame): signal.signal(signal.SIGALRM, old_sigalrm) -async def test_custom_deliver_cancel(): +async def test_custom_deliver_cancel() -> None: custom_deliver_cancel_called = False - async def custom_deliver_cancel(proc): + async def custom_deliver_cancel(proc) -> None: nonlocal custom_deliver_cancel_called custom_deliver_cancel_called = True proc.terminate() @@ -531,7 +531,7 @@ async def custom_deliver_cancel(proc): assert custom_deliver_cancel_called -async def test_warn_on_failed_cancel_terminate(monkeypatch): +async def test_warn_on_failed_cancel_terminate(monkeypatch) -> None: original_terminate = Process.terminate def broken_terminate(self): @@ -548,7 +548,7 @@ def broken_terminate(self): @pytest.mark.skipif(not posix, reason="posix only") -async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch): +async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch) -> None: monkeypatch.setattr(Process, "terminate", lambda *args: None) with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"): @@ -560,7 +560,7 @@ async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch): # the background_process_param exercises a lot of run_process cases, but it uses # check=False, so lets have a test that uses check=True as well -async def test_run_process_background_fail(): +async def test_run_process_background_fail() -> None: with pytest.raises(subprocess.CalledProcessError): async with _core.open_nursery() as nursery: proc = await nursery.start(run_process, EXIT_FALSE) @@ -571,7 +571,7 @@ async def test_run_process_background_fail(): not SyncPath("/dev/fd").exists(), reason="requires a way to iterate through open files", ) -async def test_for_leaking_fds(): +async def test_for_leaking_fds() -> None: starting_fds = set(SyncPath("/dev/fd").iterdir()) await run_process(EXIT_TRUE) assert set(SyncPath("/dev/fd").iterdir()) == starting_fds @@ -586,7 +586,7 @@ async def test_for_leaking_fds(): # regression test for #2209 -async def test_subprocess_pidfd_unnotified(): +async def test_subprocess_pidfd_unnotified() -> None: noticed_exit = None async def wait_and_tell(proc: Process) -> None: diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py index 7de42b86f9..4325e33a6c 100644 --- a/trio/_tests/test_sync.py +++ b/trio/_tests/test_sync.py @@ -8,7 +8,7 @@ from ..testing import assert_checkpoints, wait_all_tasks_blocked -async def test_Event(): +async def test_Event() -> None: e = Event() assert not e.is_set() assert e.statistics().tasks_waiting == 0 @@ -22,7 +22,7 @@ async def test_Event(): record = [] - async def child(): + async def child() -> None: record.append("sleeping") await e.wait() record.append("woken") @@ -38,7 +38,7 @@ async def child(): assert record == ["sleeping", "sleeping", "woken", "woken"] -async def test_CapacityLimiter(): +async def test_CapacityLimiter() -> None: with pytest.raises(TypeError): CapacityLimiter(1.0) with pytest.raises(ValueError): @@ -107,7 +107,7 @@ async def test_CapacityLimiter(): c.release_on_behalf_of("value 1") -async def test_CapacityLimiter_inf(): +async def test_CapacityLimiter_inf() -> None: from math import inf c = CapacityLimiter(inf) @@ -123,7 +123,7 @@ async def test_CapacityLimiter_inf(): assert c.available_tokens == inf -async def test_CapacityLimiter_change_total_tokens(): +async def test_CapacityLimiter_change_total_tokens() -> None: c = CapacityLimiter(2) with pytest.raises(TypeError): @@ -160,7 +160,7 @@ async def test_CapacityLimiter_change_total_tokens(): # regression test for issue #548 -async def test_CapacityLimiter_memleak_548(): +async def test_CapacityLimiter_memleak_548() -> None: limiter = CapacityLimiter(total_tokens=1) await limiter.acquire() @@ -174,7 +174,7 @@ async def test_CapacityLimiter_memleak_548(): assert len(limiter._pending_borrowers) == 0 -async def test_Semaphore(): +async def test_Semaphore() -> None: with pytest.raises(TypeError): Semaphore(1.0) with pytest.raises(ValueError): @@ -204,7 +204,7 @@ async def test_Semaphore(): record = [] - async def do_acquire(s): + async def do_acquire(s) -> None: record.append("started") await s.acquire() record.append("finished") @@ -222,7 +222,7 @@ async def do_acquire(s): assert record == ["started", "finished"] -async def test_Semaphore_bounded(): +async def test_Semaphore_bounded() -> None: with pytest.raises(TypeError): Semaphore(1, max_value=1.0) with pytest.raises(ValueError): @@ -240,7 +240,7 @@ async def test_Semaphore_bounded(): @pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__) -async def test_Lock_and_StrictFIFOLock(lockcls): +async def test_Lock_and_StrictFIFOLock(lockcls) -> None: l = lockcls() # noqa assert not l.locked() @@ -277,7 +277,7 @@ async def test_Lock_and_StrictFIFOLock(lockcls): holder_task = None - async def holder(): + async def holder() -> None: nonlocal holder_task holder_task = _core.current_task() async with l: @@ -315,7 +315,7 @@ async def holder(): assert statistics.tasks_waiting == 0 -async def test_Condition(): +async def test_Condition() -> None: with pytest.raises(TypeError): Condition(Semaphore(1)) with pytest.raises(TypeError): @@ -349,7 +349,7 @@ async def test_Condition(): finished_waiters = set() - async def waiter(i): + async def waiter(i) -> None: async with c: await c.wait() finished_waiters.add(i) @@ -407,56 +407,56 @@ async def waiter(i): class ChannelLock1(AsyncContextManagerMixin): - def __init__(self, capacity): + def __init__(self, capacity) -> None: self.s, self.r = open_memory_channel(capacity) for _ in range(capacity - 1): self.s.send_nowait(None) - def acquire_nowait(self): + def acquire_nowait(self) -> None: self.s.send_nowait(None) - async def acquire(self): + async def acquire(self) -> None: await self.s.send(None) - def release(self): + def release(self) -> None: self.r.receive_nowait() class ChannelLock2(AsyncContextManagerMixin): - def __init__(self): + def __init__(self) -> None: self.s, self.r = open_memory_channel(10) self.s.send_nowait(None) - def acquire_nowait(self): + def acquire_nowait(self) -> None: self.r.receive_nowait() - async def acquire(self): + async def acquire(self) -> None: await self.r.receive() - def release(self): + def release(self) -> None: self.s.send_nowait(None) class ChannelLock3(AsyncContextManagerMixin): - def __init__(self): + def __init__(self) -> None: self.s, self.r = open_memory_channel(0) # self.acquired is true when one task acquires the lock and # only becomes false when it's released and no tasks are # waiting to acquire. self.acquired = False - def acquire_nowait(self): + def acquire_nowait(self) -> None: assert not self.acquired self.acquired = True - async def acquire(self): + async def acquire(self) -> None: if self.acquired: await self.s.send(None) else: self.acquired = True await _core.checkpoint() - def release(self): + def release(self) -> None: try: self.r.receive_nowait() except _core.WouldBlock: @@ -493,13 +493,13 @@ def release(self): # Spawn a bunch of workers that take a lock and then yield; make sure that # only one worker is ever in the critical section at a time. @generic_lock_test -async def test_generic_lock_exclusion(lock_factory): +async def test_generic_lock_exclusion(lock_factory) -> None: LOOPS = 10 WORKERS = 5 in_critical_section = False acquires = 0 - async def worker(lock_like): + async def worker(lock_like) -> None: nonlocal in_critical_section, acquires for _ in range(LOOPS): async with lock_like: @@ -522,12 +522,12 @@ async def worker(lock_like): # Several workers queue on the same lock; make sure they each get it, in # order. @generic_lock_test -async def test_generic_lock_fifo_fairness(lock_factory): +async def test_generic_lock_fifo_fairness(lock_factory) -> None: initial_order = [] record = [] LOOPS = 5 - async def loopy(name, lock_like): + async def loopy(name, lock_like) -> None: # Record the order each task was initially scheduled in initial_order.append(name) for _ in range(LOOPS): @@ -546,12 +546,12 @@ async def loopy(name, lock_like): @generic_lock_test -async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory): +async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory) -> None: lock_like = lock_factory() record = [] - async def lock_taker(): + async def lock_taker() -> None: record.append("started") async with lock_like: pass diff --git a/trio/_tests/test_testing.py b/trio/_tests/test_testing.py index 26d9298807..e7e0f95535 100644 --- a/trio/_tests/test_testing.py +++ b/trio/_tests/test_testing.py @@ -13,15 +13,15 @@ from ..testing._memory_streams import _UnboundedByteQueue -async def test_wait_all_tasks_blocked(): +async def test_wait_all_tasks_blocked() -> None: record = [] - async def busy_bee(): + async def busy_bee() -> None: for _ in range(10): await _core.checkpoint() record.append("busy bee exhausted") - async def waiting_for_bee_to_leave(): + async def waiting_for_bee_to_leave() -> None: await wait_all_tasks_blocked() record.append("quiet at last!") @@ -33,7 +33,7 @@ async def waiting_for_bee_to_leave(): # check cancellation record = [] - async def cancelled_while_waiting(): + async def cancelled_while_waiting() -> None: try: await wait_all_tasks_blocked() except _core.Cancelled: @@ -45,10 +45,10 @@ async def cancelled_while_waiting(): assert record == ["ok"] -async def test_wait_all_tasks_blocked_with_timeouts(mock_clock): +async def test_wait_all_tasks_blocked_with_timeouts(mock_clock) -> None: record = [] - async def timeout_task(): + async def timeout_task() -> None: record.append("tt start") await sleep(5) record.append("tt finished") @@ -62,25 +62,25 @@ async def timeout_task(): assert record == ["tt start", "tt finished"] -async def test_wait_all_tasks_blocked_with_cushion(): +async def test_wait_all_tasks_blocked_with_cushion() -> None: record = [] - async def blink(): + async def blink() -> None: record.append("blink start") await sleep(0.01) await sleep(0.01) await sleep(0.01) record.append("blink end") - async def wait_no_cushion(): + async def wait_no_cushion() -> None: await wait_all_tasks_blocked() record.append("wait_no_cushion end") - async def wait_small_cushion(): + async def wait_small_cushion() -> None: await wait_all_tasks_blocked(0.02) record.append("wait_small_cushion end") - async def wait_big_cushion(): + async def wait_big_cushion() -> None: await wait_all_tasks_blocked(0.03) record.append("wait_big_cushion end") @@ -104,7 +104,7 @@ async def wait_big_cushion(): ################################################################ -async def test_assert_checkpoints(recwarn): +async def test_assert_checkpoints(recwarn) -> None: with assert_checkpoints(): await _core.checkpoint() @@ -130,7 +130,7 @@ async def test_assert_checkpoints(recwarn): await _core.cancel_shielded_checkpoint() -async def test_assert_no_checkpoints(recwarn): +async def test_assert_no_checkpoints(recwarn) -> None: with assert_no_checkpoints(): 1 + 1 @@ -160,14 +160,14 @@ async def test_assert_no_checkpoints(recwarn): ################################################################ -async def test_Sequencer(): +async def test_Sequencer() -> None: record = [] - def t(val): + def t(val) -> None: print(val) record.append(val) - async def f1(seq): + async def f1(seq) -> None: async with seq(1): t(("f1", 1)) async with seq(3): @@ -175,7 +175,7 @@ async def f1(seq): async with seq(4): t(("f1", 4)) - async def f2(seq): + async def f2(seq) -> None: async with seq(0): t(("f2", 0)) async with seq(2): @@ -198,12 +198,12 @@ async def f2(seq): pass # pragma: no cover -async def test_Sequencer_cancel(): +async def test_Sequencer_cancel() -> None: # Killing a blocked task makes everything blow up record = [] seq = Sequencer() - async def child(i): + async def child(i) -> None: with _core.CancelScope() as scope: if i == 1: scope.cancel() @@ -245,7 +245,7 @@ async def test__assert_raises(): # This is a private implementation detail, but it's complex enough to be worth # testing directly -async def test__UnboundeByteQueue(): +async def test__UnboundeByteQueue() -> None: ubq = _UnboundedByteQueue() ubq.put(b"123") @@ -273,11 +273,11 @@ async def test__UnboundeByteQueue(): with assert_checkpoints(): assert await ubq.get() == b"efghi" - async def putter(data): + async def putter(data) -> None: await wait_all_tasks_blocked() ubq.put(data) - async def getter(expect): + async def getter(expect) -> None: with assert_checkpoints(): assert await ubq.get() == expect @@ -308,7 +308,7 @@ async def getter(expect): # close wakes up blocked getters ubq2 = _UnboundedByteQueue() - async def closer(): + async def closer() -> None: await wait_all_tasks_blocked() ubq2.close() @@ -317,10 +317,10 @@ async def closer(): nursery.start_soon(closer) -async def test_MemorySendStream(): +async def test_MemorySendStream() -> None: mss = MemorySendStream() - async def do_send_all(data): + async def do_send_all(data) -> None: with assert_checkpoints(): await mss.send_all(data) @@ -346,7 +346,7 @@ async def do_send_all(data): # and we don't know which one will get the error. resource_busy_count = 0 - async def do_send_all_count_resourcebusy(): + async def do_send_all_count_resourcebusy() -> None: nonlocal resource_busy_count try: await do_send_all(b"xxx") @@ -375,15 +375,15 @@ async def do_send_all_count_resourcebusy(): record = [] - async def send_all_hook(): + async def send_all_hook() -> None: # hook runs after send_all does its work (can pull data out) assert mss2.get_data_nowait() == b"abc" record.append("send_all_hook") - async def wait_send_all_might_not_block_hook(): + async def wait_send_all_might_not_block_hook() -> None: record.append("wait_send_all_might_not_block_hook") - def close_hook(): + def close_hook() -> None: record.append("close_hook") mss2 = MemorySendStream( @@ -407,7 +407,7 @@ def close_hook(): ] -async def test_MemoryReceiveStream(): +async def test_MemoryReceiveStream() -> None: mrs = MemoryReceiveStream() async def do_receive_some(max_bytes): @@ -438,12 +438,12 @@ async def do_receive_some(max_bytes): with pytest.raises(_core.ClosedResourceError): mrs.put_data(b"---") - async def receive_some_hook(): + async def receive_some_hook() -> None: mrs2.put_data(b"xxx") record = [] - def close_hook(): + def close_hook() -> None: record.append("closed") mrs2 = MemoryReceiveStream(receive_some_hook, close_hook) @@ -468,7 +468,7 @@ def close_hook(): await mrs2.receive_some(10) -async def test_MemoryRecvStream_closing(): +async def test_MemoryRecvStream_closing() -> None: mrs = MemoryReceiveStream() # close with no pending data mrs.close() @@ -488,7 +488,7 @@ async def test_MemoryRecvStream_closing(): await mrs2.receive_some(10) -async def test_memory_stream_pump(): +async def test_memory_stream_pump() -> None: mss = MemorySendStream() mrs = MemoryReceiveStream() @@ -512,7 +512,7 @@ async def test_memory_stream_pump(): assert await mrs.receive_some(10) == b"" -async def test_memory_stream_one_way_pair(): +async def test_memory_stream_one_way_pair() -> None: s, r = memory_stream_one_way_pair() assert s.send_all_hook is not None assert s.wait_send_all_might_not_block_hook is None @@ -521,7 +521,7 @@ async def test_memory_stream_one_way_pair(): await s.send_all(b"123") assert await r.receive_some(10) == b"123" - async def receiver(expected): + async def receiver(expected) -> None: assert await r.receive_some(10) == expected # This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook @@ -549,11 +549,11 @@ async def receiver(expected): s.send_all_hook = None await s.send_all(b"456") - async def cancel_after_idle(nursery): + async def cancel_after_idle(nursery) -> None: await wait_all_tasks_blocked() nursery.cancel_scope.cancel() - async def check_for_cancel(): + async def check_for_cancel() -> None: with pytest.raises(_core.Cancelled): # This should block forever... or until cancelled. Even though we # sent some data on the send stream. @@ -568,7 +568,7 @@ async def check_for_cancel(): assert await r.receive_some(10) == b"456789" -async def test_memory_stream_pair(): +async def test_memory_stream_pair() -> None: a, b = memory_stream_pair() await a.send_all(b"123") await b.send_all(b"abc") @@ -578,11 +578,11 @@ async def test_memory_stream_pair(): await a.send_eof() assert await b.receive_some(10) == b"" - async def sender(): + async def sender() -> None: await wait_all_tasks_blocked() await b.send_all(b"xyz") - async def receiver(): + async def receiver() -> None: assert await a.receive_some(10) == b"xyz" async with _core.open_nursery() as nursery: @@ -590,7 +590,7 @@ async def receiver(): nursery.start_soon(sender) -async def test_memory_streams_with_generic_tests(): +async def test_memory_streams_with_generic_tests() -> None: async def one_way_stream_maker(): return memory_stream_one_way_pair() @@ -602,7 +602,7 @@ async def half_closeable_stream_maker(): await check_half_closeable_stream(half_closeable_stream_maker, None) -async def test_lockstep_streams_with_generic_tests(): +async def test_lockstep_streams_with_generic_tests() -> None: async def one_way_stream_maker(): return lockstep_stream_one_way_pair() @@ -614,8 +614,8 @@ async def two_way_stream_maker(): await check_two_way_stream(two_way_stream_maker, two_way_stream_maker) -async def test_open_stream_to_socket_listener(): - async def check(listener): +async def test_open_stream_to_socket_listener() -> None: + async def check(listener) -> None: async with listener: client_stream = await open_stream_to_socket_listener(listener) async with client_stream: diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 9e448a4d38..246b50533f 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -27,13 +27,13 @@ from ..testing import wait_all_tasks_blocked -async def test_do_in_trio_thread(): +async def test_do_in_trio_thread() -> None: trio_thread = threading.current_thread() - async def check_case(do_in_trio_thread, fn, expected, trio_token=None): + async def check_case(do_in_trio_thread, fn, expected, trio_token=None) -> None: record = [] - def threadfn(): + def threadfn() -> None: try: record.append(("start", threading.current_thread())) x = do_in_trio_thread(fn, record, trio_token=trio_token) @@ -51,7 +51,7 @@ def threadfn(): token = _core.current_trio_token() - def f(record): + def f(record) -> int: assert not _core.currently_ki_protected() record.append(("f", threading.current_thread())) return 2 @@ -65,7 +65,7 @@ def f(record): await check_case(from_thread_run_sync, f, ("error", ValueError), trio_token=token) - async def f(record): + async def f(record) -> int: assert not _core.currently_ki_protected() await _core.checkpoint() record.append(("f", threading.current_thread())) @@ -82,26 +82,26 @@ async def f(record): await check_case(from_thread_run, f, ("error", KeyError), trio_token=token) -async def test_do_in_trio_thread_from_trio_thread(): +async def test_do_in_trio_thread_from_trio_thread() -> None: with pytest.raises(RuntimeError): from_thread_run_sync(lambda: None) # pragma: no branch - async def foo(): # pragma: no cover + async def foo() -> None: # pragma: no cover pass with pytest.raises(RuntimeError): from_thread_run(foo) -def test_run_in_trio_thread_ki(): +def test_run_in_trio_thread_ki() -> None: # if we get a control-C during a run_in_trio_thread, then it propagates # back to the caller (slick!) record = set() - async def check_run_in_trio_thread(): + async def check_run_in_trio_thread() -> None: token = _core.current_trio_token() - def trio_thread_fn(): + def trio_thread_fn() -> None: print("in Trio thread") assert not _core.currently_ki_protected() print("ki_self") @@ -112,10 +112,10 @@ def trio_thread_fn(): print("finally", sys.exc_info()) - async def trio_thread_afn(): + async def trio_thread_afn() -> None: trio_thread_fn() - def external_thread_fn(): + def external_thread_fn() -> None: try: print("running") from_thread_run_sync(trio_thread_fn, trio_token=token) @@ -141,16 +141,16 @@ def external_thread_fn(): assert record == {"ok1", "ok2"} -def test_await_in_trio_thread_while_main_exits(): +def test_await_in_trio_thread_while_main_exits() -> None: record = [] ev = Event() - async def trio_fn(): + async def trio_fn() -> None: record.append("sleeping") ev.set() await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) - def thread_fn(token): + def thread_fn(token) -> None: try: from_thread_run(trio_fn, trio_token=token) except _core.Cancelled: @@ -169,7 +169,7 @@ async def main(): assert record == ["sleeping", "cancelled"] -async def test_named_thread(): +async def test_named_thread() -> None: ending = " from trio._tests.test_threads.test_named_thread" def inner(name: str = "inner" + ending) -> threading.Thread: @@ -236,7 +236,7 @@ def _get_thread_name(ident: Optional[int] = None) -> Optional[str]: # this depends on pthread being available, which is the case on 99.9% of linux machines # and most mac machines. So unless the platform is linux it will just skip # in case it fails to fetch the os thread name. -async def test_named_thread_os(): +async def test_named_thread_os() -> None: def inner(name: str) -> threading.Thread: os_thread_name = _get_thread_name() if os_thread_name is None and sys.platform != "linux": @@ -271,7 +271,7 @@ async def test_thread_name(name: str, expected: Optional[str] = None) -> None: await test_thread_name("💙", expected="?") -async def test_has_pthread_setname_np(): +async def test_has_pthread_setname_np() -> None: from trio._core._thread_cache import get_os_thread_name_func k = get_os_thread_name_func() @@ -280,7 +280,7 @@ async def test_has_pthread_setname_np(): pytest.skip(f"no pthread_setname_np on {sys.platform}") -async def test_run_in_worker_thread(): +async def test_run_in_worker_thread() -> None: trio_thread = threading.current_thread() def f(x): @@ -299,10 +299,10 @@ def g(): assert excinfo.value.args[0] != trio_thread -async def test_run_in_worker_thread_cancellation(): +async def test_run_in_worker_thread_cancellation() -> None: register = [None] - def f(q): + def f(q) -> None: # Make the thread block for a controlled amount of time register[0] = "blocking" q.get() @@ -359,18 +359,18 @@ async def child(q, cancellable): # Make sure that if trio.run exits, and then the thread finishes, then that's # handled gracefully. (Requires that the thread result machinery be prepared # for call_soon to raise RunFinishedError.) -def test_run_in_worker_thread_abandoned(capfd, monkeypatch): +def test_run_in_worker_thread_abandoned(capfd, monkeypatch) -> None: monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) q1 = stdlib_queue.Queue() q2 = stdlib_queue.Queue() - def thread_fn(): + def thread_fn() -> None: q1.get() q2.put(threading.current_thread()) - async def main(): - async def child(): + async def main() -> None: + async def child() -> None: await to_thread_run_sync(thread_fn, cancellable=True) async with _core.open_nursery() as nursery: @@ -397,7 +397,7 @@ async def child(): @pytest.mark.parametrize("MAX", [3, 5, 10]) @pytest.mark.parametrize("cancel", [False, True]) @pytest.mark.parametrize("use_default_limiter", [False, True]) -async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter): +async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter) -> None: # This test is a bit tricky. The goal is to make sure that if we set # limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever # running at a time, even if there are more concurrent calls to @@ -436,7 +436,7 @@ class state: token = _core.current_trio_token() - def thread_fn(cancel_scope): + def thread_fn(cancel_scope) -> None: print("thread_fn start") from_thread_run_sync(cancel_scope.cancel, trio_token=token) with lock: @@ -452,7 +452,7 @@ def thread_fn(cancel_scope): state.running -= 1 print("thread_fn exiting") - async def run_thread(event): + async def run_thread(event) -> None: with _core.CancelScope() as cancel_scope: await to_thread_run_sync( thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel @@ -501,17 +501,17 @@ async def run_thread(event): c.total_tokens = orig_total_tokens -async def test_run_in_worker_thread_custom_limiter(): +async def test_run_in_worker_thread_custom_limiter() -> None: # Basically just checking that we only call acquire_on_behalf_of and # release_on_behalf_of, since that's part of our documented API. record = [] class CustomLimiter: - async def acquire_on_behalf_of(self, borrower): + async def acquire_on_behalf_of(self, borrower) -> None: record.append("acquire") self._borrower = borrower - def release_on_behalf_of(self, borrower): + def release_on_behalf_of(self, borrower) -> None: record.append("release") assert borrower == self._borrower @@ -519,11 +519,11 @@ def release_on_behalf_of(self, borrower): assert record == ["acquire", "release"] -async def test_run_in_worker_thread_limiter_error(): +async def test_run_in_worker_thread_limiter_error() -> None: record = [] class BadCapacityLimiter: - async def acquire_on_behalf_of(self, borrower): + async def acquire_on_behalf_of(self, borrower) -> None: record.append("acquire") def release_on_behalf_of(self, borrower): @@ -547,7 +547,7 @@ def release_on_behalf_of(self, borrower): assert record == ["acquire", "release"] -async def test_run_in_worker_thread_fail_to_spawn(monkeypatch): +async def test_run_in_worker_thread_fail_to_spawn(monkeypatch) -> None: # Test the unlikely but possible case where trying to spawn a thread fails def bad_start(self, *args): raise RuntimeError("the engines canna take it captain") @@ -565,7 +565,7 @@ def bad_start(self, *args): assert limiter.borrowed_tokens == 0 -async def test_trio_to_thread_run_sync_token(): +async def test_trio_to_thread_run_sync_token() -> None: # Test that to_thread_run_sync automatically injects the current trio token # into a spawned thread def thread_fn(): @@ -577,9 +577,9 @@ def thread_fn(): assert callee_token == caller_token -async def test_trio_to_thread_run_sync_expected_error(): +async def test_trio_to_thread_run_sync_expected_error() -> None: # Test correct error when passed async function - async def async_fn(): # pragma: no cover + async def async_fn() -> None: # pragma: no cover pass with pytest.raises(TypeError, match="expected a sync function"): @@ -591,7 +591,7 @@ async def async_fn(): # pragma: no cover ) -async def test_trio_to_thread_run_sync_contextvars(): +async def test_trio_to_thread_run_sync_contextvars() -> None: trio_thread = threading.current_thread() trio_test_contextvar.set("main") @@ -628,7 +628,7 @@ def g(): assert sniffio.current_async_library() == "trio" -async def test_trio_from_thread_run_sync(): +async def test_trio_from_thread_run_sync() -> None: # Test that to_thread_run_sync correctly "hands off" the trio token to # trio.from_thread.run_sync() def thread_fn(): @@ -639,26 +639,26 @@ def thread_fn(): assert isinstance(trio_time, float) # Test correct error when passed async function - async def async_fn(): # pragma: no cover + async def async_fn() -> None: # pragma: no cover pass - def thread_fn(): + def thread_fn() -> None: from_thread_run_sync(async_fn) with pytest.raises(TypeError, match="expected a sync function"): await to_thread_run_sync(thread_fn) -async def test_trio_from_thread_run(): +async def test_trio_from_thread_run() -> None: # Test that to_thread_run_sync correctly "hands off" the trio token to # trio.from_thread.run() record = [] - async def back_in_trio_fn(): + async def back_in_trio_fn() -> None: _core.current_time() # implicitly checks that we're in trio record.append("back in trio") - def thread_fn(): + def thread_fn() -> None: record.append("in thread") from_thread_run(back_in_trio_fn) @@ -666,14 +666,14 @@ def thread_fn(): assert record == ["in thread", "back in trio"] # Test correct error when passed sync function - def sync_fn(): # pragma: no cover + def sync_fn() -> None: # pragma: no cover pass with pytest.raises(TypeError, match="appears to be synchronous"): await to_thread_run_sync(from_thread_run, sync_fn) -async def test_trio_from_thread_token(): +async def test_trio_from_thread_token() -> None: # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() # share the same Trio token def thread_fn(): @@ -685,7 +685,7 @@ def thread_fn(): assert callee_token == caller_token -async def test_trio_from_thread_token_kwarg(): +async def test_trio_from_thread_token_kwarg() -> None: # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can # use an explicitly defined token def thread_fn(token): @@ -697,7 +697,7 @@ def thread_fn(token): assert callee_token == caller_token -async def test_from_thread_no_token(): +async def test_from_thread_no_token() -> None: # Test that a "raw call" to trio.from_thread.run() fails because no token # has been provided @@ -705,7 +705,7 @@ async def test_from_thread_no_token(): from_thread_run_sync(_core.current_time) -async def test_trio_from_thread_run_sync_contextvars(): +async def test_trio_from_thread_run_sync_contextvars() -> None: trio_test_contextvar.set("main") def thread_fn(): @@ -748,7 +748,7 @@ def back_in_main(): assert back_current_value == "back_in_main" -async def test_trio_from_thread_run_contextvars(): +async def test_trio_from_thread_run_contextvars() -> None: trio_test_contextvar.set("main") def thread_fn(): @@ -791,13 +791,13 @@ async def async_back_in_main(): assert sniffio.current_async_library() == "trio" -def test_run_fn_as_system_task_catched_badly_typed_token(): +def test_run_fn_as_system_task_catched_badly_typed_token() -> None: with pytest.raises(RuntimeError): from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype") -async def test_from_thread_inside_trio_thread(): - def not_called(): # pragma: no cover +async def test_from_thread_inside_trio_thread() -> None: + def not_called() -> None: # pragma: no cover assert False trio_token = _core.current_trio_token() @@ -806,7 +806,7 @@ def not_called(): # pragma: no cover @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_from_thread_run_during_shutdown(): +def test_from_thread_run_during_shutdown() -> None: save = [] record = [] @@ -818,7 +818,7 @@ async def agen(): await to_thread_run_sync(from_thread_run, sleep, 0) record.append("ok") - async def main(): + async def main() -> None: save.append(agen()) await save[-1].asend(None) @@ -826,18 +826,18 @@ async def main(): assert record == ["ok"] -async def test_trio_token_weak_referenceable(): +async def test_trio_token_weak_referenceable() -> None: token = current_trio_token() assert isinstance(token, TrioToken) weak_reference = weakref.ref(token) assert token is weak_reference() -async def test_unsafe_cancellable_kwarg(): +async def test_unsafe_cancellable_kwarg() -> None: # This is a stand in for a numpy ndarray or other objects # that (maybe surprisingly) lack a notion of truthiness class BadBool: - def __bool__(self): + def __bool__(self) -> bool: raise NotImplementedError with pytest.raises(NotImplementedError): diff --git a/trio/_tests/test_timeouts.py b/trio/_tests/test_timeouts.py index 9507d88a78..1491067c05 100644 --- a/trio/_tests/test_timeouts.py +++ b/trio/_tests/test_timeouts.py @@ -43,13 +43,13 @@ async def check_takes_about(f, expected_dur): @slow -async def test_sleep(): - async def sleep_1(): +async def test_sleep() -> None: + async def sleep_1() -> None: await sleep_until(_core.current_time() + TARGET) await check_takes_about(sleep_1, TARGET) - async def sleep_2(): + async def sleep_2() -> None: await sleep(TARGET) await check_takes_about(sleep_2, TARGET) @@ -63,8 +63,8 @@ async def sleep_2(): @slow -async def test_move_on_after(): - async def sleep_3(): +async def test_move_on_after() -> None: + async def sleep_3() -> None: with move_on_after(TARGET): await sleep(100) @@ -72,8 +72,8 @@ async def sleep_3(): @slow -async def test_fail(): - async def sleep_4(): +async def test_fail() -> None: + async def sleep_4() -> None: with fail_at(_core.current_time() + TARGET): await sleep(100) @@ -83,7 +83,7 @@ async def sleep_4(): with fail_at(_core.current_time() + 100): await sleep(0) - async def sleep_5(): + async def sleep_5() -> None: with fail_after(TARGET): await sleep(100) @@ -94,7 +94,7 @@ async def sleep_5(): await sleep(0) -async def test_timeouts_raise_value_error(): +async def test_timeouts_raise_value_error() -> None: # deadlines are allowed to be negative, but not delays. # neither delays nor deadlines are allowed to be NaN diff --git a/trio/_tests/test_tracing.py b/trio/_tests/test_tracing.py index e5110eaff3..0cef2b0f44 100644 --- a/trio/_tests/test_tracing.py +++ b/trio/_tests/test_tracing.py @@ -25,7 +25,7 @@ async def coro3_async_gen(event: trio.Event) -> None: pass -async def test_task_iter_await_frames(): +async def test_task_iter_await_frames() -> None: async with trio.open_nursery() as nursery: event = trio.Event() nursery.start_soon(coro3, event) @@ -42,7 +42,7 @@ async def test_task_iter_await_frames(): nursery.cancel_scope.cancel() -async def test_task_iter_await_frames_async_gen(): +async def test_task_iter_await_frames_async_gen() -> None: async with trio.open_nursery() as nursery: event = trio.Event() nursery.start_soon(coro3_async_gen, event) diff --git a/trio/_tests/test_unix_pipes.py b/trio/_tests/test_unix_pipes.py index 0b0d2ceb23..f29ee241c6 100644 --- a/trio/_tests/test_unix_pipes.py +++ b/trio/_tests/test_unix_pipes.py @@ -57,7 +57,7 @@ async def make_clogged_pipe(): return s, r -async def test_send_pipe(): +async def test_send_pipe() -> None: r, w = os.pipe() async with FdStream(w) as send: assert send.fileno() == w @@ -67,7 +67,7 @@ async def test_send_pipe(): os.close(r) -async def test_receive_pipe(): +async def test_receive_pipe() -> None: r, w = os.pipe() async with FdStream(r) as recv: assert (recv.fileno()) == r @@ -77,15 +77,15 @@ async def test_receive_pipe(): os.close(w) -async def test_pipes_combined(): +async def test_pipes_combined() -> None: write, read = await make_pipe() count = 2**20 - async def sender(): + async def sender() -> None: big = bytearray(count) await write.send_all(big) - async def reader(): + async def reader() -> None: await wait_all_tasks_blocked() received = 0 while received < count: @@ -101,7 +101,7 @@ async def reader(): await write.aclose() -async def test_pipe_errors(): +async def test_pipe_errors() -> None: with pytest.raises(TypeError): FdStream(None) @@ -112,7 +112,7 @@ async def test_pipe_errors(): await s.receive_some(0) -async def test_del(): +async def test_del() -> None: w, r = await make_pipe() f1, f2 = w.fileno(), r.fileno() del w, r @@ -127,7 +127,7 @@ async def test_del(): assert excinfo.value.errno == errno.EBADF -async def test_async_with(): +async def test_async_with() -> None: w, r = await make_pipe() async with w, r: pass @@ -144,7 +144,7 @@ async def test_async_with(): assert excinfo.value.errno == errno.EBADF -async def test_misdirected_aclose_regression(): +async def test_misdirected_aclose_regression() -> None: # https://github.com/python-trio/trio/issues/661#issuecomment-456582356 w, r = await make_pipe() old_r_fd = r.fileno() @@ -164,7 +164,7 @@ async def test_misdirected_aclose_regression(): # And now set up a background task that's working on the new receive # handle - async def expect_eof(): + async def expect_eof() -> None: assert await r2.receive_some(10) == b"" async with _core.open_nursery() as nursery: @@ -182,7 +182,7 @@ async def expect_eof(): os.close(w2_fd) -async def test_close_at_bad_time_for_receive_some(monkeypatch): +async def test_close_at_bad_time_for_receive_some(monkeypatch) -> None: # We used to have race conditions where if one task was using the pipe, # and another closed it at *just* the wrong moment, it would give an # unexpected error instead of ClosedResourceError: @@ -190,13 +190,13 @@ async def test_close_at_bad_time_for_receive_some(monkeypatch): # # This tests what happens if the pipe gets closed in the moment *between* # when receive_some wakes up, and when it tries to call os.read - async def expect_closedresourceerror(): + async def expect_closedresourceerror() -> None: with pytest.raises(_core.ClosedResourceError): await r.receive_some(10) orig_wait_readable = _core._run.TheIOManager.wait_readable - async def patched_wait_readable(*args, **kwargs): + async def patched_wait_readable(*args, **kwargs) -> None: await orig_wait_readable(*args, **kwargs) await r.aclose() @@ -210,7 +210,7 @@ async def patched_wait_readable(*args, **kwargs): await s.send_all(b"x") -async def test_close_at_bad_time_for_send_all(monkeypatch): +async def test_close_at_bad_time_for_send_all(monkeypatch) -> None: # We used to have race conditions where if one task was using the pipe, # and another closed it at *just* the wrong moment, it would give an # unexpected error instead of ClosedResourceError: @@ -218,13 +218,13 @@ async def test_close_at_bad_time_for_send_all(monkeypatch): # # This tests what happens if the pipe gets closed in the moment *between* # when send_all wakes up, and when it tries to call os.write - async def expect_closedresourceerror(): + async def expect_closedresourceerror() -> None: with pytest.raises(_core.ClosedResourceError): await s.send_all(b"x" * 100) orig_wait_writable = _core._run.TheIOManager.wait_writable - async def patched_wait_writable(*args, **kwargs): + async def patched_wait_writable(*args, **kwargs) -> None: await orig_wait_writable(*args, **kwargs) await s.aclose() @@ -257,7 +257,7 @@ async def patched_wait_writable(*args, **kwargs): sys.platform.startswith("freebsd"), reason="no way to make read() return a bizarro error on FreeBSD", ) -async def test_bizarro_OSError_from_receive(): +async def test_bizarro_OSError_from_receive() -> None: # Make sure that if the read syscall returns some bizarro error, then we # get a BrokenResourceError. This is incredibly unlikely; there's almost # no way to trigger a failure here intentionally (except for EBADF, but we @@ -277,5 +277,5 @@ async def test_bizarro_OSError_from_receive(): @skip_if_fbsd_pipes_broken -async def test_pipe_fully(): +async def test_pipe_fully() -> None: await check_one_way_stream(make_pipe, make_clogged_pipe) diff --git a/trio/_tests/test_util.py b/trio/_tests/test_util.py index 9cffaa30d3..ef99e8f66f 100644 --- a/trio/_tests/test_util.py +++ b/trio/_tests/test_util.py @@ -24,10 +24,10 @@ from ..testing import wait_all_tasks_blocked -def test_signal_raise(): +def test_signal_raise() -> None: record = [] - def handler(signum, _): + def handler(signum, _) -> None: record.append(signum) old = signal.signal(signal.SIGFPE, handler) @@ -38,7 +38,7 @@ def handler(signum, _): assert record == [signal.SIGFPE] -async def test_ConflictDetector(): +async def test_ConflictDetector() -> None: ul1 = ConflictDetector("ul1") ul2 = ConflictDetector("ul2") @@ -52,7 +52,7 @@ async def test_ConflictDetector(): pass # pragma: no cover assert "ul1" in str(excinfo.value) - async def wait_with_ul1(): + async def wait_with_ul1() -> None: with ul1: await wait_all_tasks_blocked() @@ -63,7 +63,7 @@ async def wait_with_ul1(): assert "ul1" in str(excinfo.value) -def test_module_metadata_is_fixed_up(): +def test_module_metadata_is_fixed_up() -> None: import trio import trio.testing @@ -87,10 +87,10 @@ def test_module_metadata_is_fixed_up(): assert trio.to_thread.run_sync.__qualname__ == "run_sync" -async def test_is_main_thread(): +async def test_is_main_thread() -> None: assert is_main_thread() - def not_main_thread(): + def not_main_thread() -> None: assert not is_main_thread() await trio.to_thread.run_sync(not_main_thread) @@ -98,13 +98,13 @@ def not_main_thread(): # @coroutine is deprecated since python 3.8, which is fine with us. @pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning") -def test_coroutine_or_error(): +def test_coroutine_or_error() -> None: class Deferred: "Just kidding" with ignore_coroutine_never_awaited_warnings(): - async def f(): # pragma: no cover + async def f() -> None: # pragma: no cover pass with pytest.raises(TypeError) as excinfo: @@ -156,7 +156,7 @@ async def async_gen(arg): # pragma: no cover del excinfo -def test_generic_function(): +def test_generic_function() -> None: @generic_function def test_func(arg): """Look, a docstring!""" @@ -187,11 +187,11 @@ class SubClass(FinalClass): # type: ignore[misc] pass -def test_no_public_constructor_metaclass(): +def test_no_public_constructor_metaclass() -> None: """The NoPublicConstructor metaclass prevents calling the constructor directly.""" class SpecialClass(metaclass=NoPublicConstructor): - def __init__(self, a: int, b: float): + def __init__(self, a: int, b: float) -> None: """Check arguments can be passed to __init__.""" assert a == 8 assert b == 3.14 @@ -203,7 +203,7 @@ def __init__(self, a: int, b: float): assert isinstance(SpecialClass._create(8, b=3.14), SpecialClass) -def test_fixup_module_metadata(): +def test_fixup_module_metadata() -> None: # Ignores modules not in the trio.X tree. non_trio_module = types.ModuleType("not_trio") non_trio_module.some_func = lambda: None diff --git a/trio/_tests/test_wait_for_object.py b/trio/_tests/test_wait_for_object.py index ea16684289..53e771b7ed 100644 --- a/trio/_tests/test_wait_for_object.py +++ b/trio/_tests/test_wait_for_object.py @@ -16,7 +16,7 @@ from .._wait_for_object import WaitForMultipleObjects_sync, WaitForSingleObject -async def test_WaitForMultipleObjects_sync(): +async def test_WaitForMultipleObjects_sync() -> None: # This does a series of tests where we set/close the handle before # initiating the waiting for it. # @@ -70,7 +70,7 @@ async def test_WaitForMultipleObjects_sync(): @slow -async def test_WaitForMultipleObjects_sync_slow(): +async def test_WaitForMultipleObjects_sync_slow() -> None: # This does a series of test in which the main thread sync-waits for # handles, while we spawn a thread to set the handles after a short while. @@ -125,7 +125,7 @@ async def test_WaitForMultipleObjects_sync_slow(): print("test_WaitForMultipleObjects_sync_slow thread-set second OK") -async def test_WaitForSingleObject(): +async def test_WaitForSingleObject() -> None: # This does a series of test for setting/closing the handle before # initiating the wait. @@ -160,7 +160,7 @@ async def test_WaitForSingleObject(): @slow -async def test_WaitForSingleObject_slow(): +async def test_WaitForSingleObject_slow() -> None: # This does a series of test for setting the handle in another task, # and cancelling the wait task. @@ -168,7 +168,7 @@ async def test_WaitForSingleObject_slow(): # the timeout with a certain margin. TIMEOUT = 0.3 - async def signal_soon_async(handle): + async def signal_soon_async(handle) -> None: await _timeouts.sleep(TIMEOUT) kernel32.SetEvent(handle) diff --git a/trio/_tests/test_windows_pipes.py b/trio/_tests/test_windows_pipes.py index 5c4bae7d25..399d7116bb 100644 --- a/trio/_tests/test_windows_pipes.py +++ b/trio/_tests/test_windows_pipes.py @@ -24,14 +24,14 @@ async def make_pipe() -> Tuple[PipeSendStream, PipeReceiveStream]: return PipeSendStream(w), PipeReceiveStream(r) -async def test_pipe_typecheck(): +async def test_pipe_typecheck() -> None: with pytest.raises(TypeError): PipeSendStream(1.0) with pytest.raises(TypeError): PipeReceiveStream(None) -async def test_pipe_error_on_close(): +async def test_pipe_error_on_close() -> None: # Make sure we correctly handle a failure from kernel32.CloseHandle r, w = pipe() @@ -47,18 +47,18 @@ async def test_pipe_error_on_close(): await receive_stream.aclose() -async def test_pipes_combined(): +async def test_pipes_combined() -> None: write, read = await make_pipe() count = 2**20 replicas = 3 - async def sender(): + async def sender() -> None: async with write: big = bytearray(count) for _ in range(replicas): await write.send_all(big) - async def reader(): + async def reader() -> None: async with read: await wait_all_tasks_blocked() total_received = 0 @@ -76,7 +76,7 @@ async def reader(): n.start_soon(reader) -async def test_async_with(): +async def test_async_with() -> None: w, r = await make_pipe() async with w, r: pass @@ -87,11 +87,11 @@ async def test_async_with(): await r.receive_some(10) -async def test_close_during_write(): +async def test_close_during_write() -> None: w, r = await make_pipe() async with _core.open_nursery() as nursery: - async def write_forever(): + async def write_forever() -> None: with pytest.raises(_core.ClosedResourceError) as excinfo: while True: await w.send_all(b"x" * 4096) @@ -102,7 +102,7 @@ async def write_forever(): await w.aclose() -async def test_pipe_fully(): +async def test_pipe_fully() -> None: # passing make_clogged_pipe tests wait_send_all_might_not_block, and we # can't implement that on Windows await check_one_way_stream(make_pipe, None) diff --git a/trio/_tests/tools/test_gen_exports.py b/trio/_tests/tools/test_gen_exports.py index b4e23916a0..ee51827a59 100644 --- a/trio/_tests/tools/test_gen_exports.py +++ b/trio/_tests/tools/test_gen_exports.py @@ -60,12 +60,12 @@ async def not_public_async(self): """ -def test_get_public_methods(): +def test_get_public_methods() -> None: methods = list(get_public_methods(ast.parse(SOURCE))) assert {m.name for m in methods} == {"public_func", "public_async_func"} -def test_create_pass_through_args(): +def test_create_pass_through_args() -> None: testcases = [ ("def f()", "()"), ("def f(one)", "(one)"), @@ -91,7 +91,7 @@ def test_create_pass_through_args(): @skip_lints @pytest.mark.parametrize("imports", ["", IMPORT_1, IMPORT_2, IMPORT_3]) -def test_process(tmp_path, imports): +def test_process(tmp_path, imports) -> None: try: import black # noqa: F401 # there's no dedicated CI run that has astor+isort, but lacks black. From a909b0a3f7a4311b66a7dde578c28f4f84d327db Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 15 Oct 2023 15:22:57 +0200 Subject: [PATCH 03/35] pyannotate --aggressive --- trio/_tests/test_exports.py | 2 +- trio/_tests/test_path.py | 6 +++--- trio/_tests/test_scheduler_determinism.py | 2 +- trio/_tests/test_ssl.py | 2 +- trio/_tests/test_subprocess.py | 4 ++-- trio/_tests/test_sync.py | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py index 7a016e86d3..6861ba3708 100644 --- a/trio/_tests/test_exports.py +++ b/trio/_tests/test_exports.py @@ -126,7 +126,7 @@ def iter_modules( # https://github.com/pypa/setuptools/issues/3274 "ignore:module 'sre_constants' is deprecated:DeprecationWarning", ) -def test_static_tool_sees_all_symbols(tool, modname, tmpdir) -> None: +def test_static_tool_sees_all_symbols(tool, modname: str, tmpdir) -> None: module = importlib.import_module(modname) def no_underscores(symbols): diff --git a/trio/_tests/test_path.py b/trio/_tests/test_path.py index 1d17029bdf..bb24764c23 100644 --- a/trio/_tests/test_path.py +++ b/trio/_tests/test_path.py @@ -14,7 +14,7 @@ def path(tmpdir): return trio.Path(p) -def method_pair(path, method_name): +def method_pair(path, method_name: str): path = pathlib.Path(path) async_path = trio.Path(path) return getattr(path, method_name), getattr(async_path, method_name) @@ -103,7 +103,7 @@ async def test_async_method_signature(path) -> None: @pytest.mark.parametrize("method_name", ["is_dir", "is_file"]) -async def test_compare_async_stat_methods(method_name) -> None: +async def test_compare_async_stat_methods(method_name: str) -> None: method, async_method = method_pair(".", method_name) result = method() @@ -118,7 +118,7 @@ async def test_invalid_name_not_wrapped(path) -> None: @pytest.mark.parametrize("method_name", ["absolute", "resolve"]) -async def test_async_methods_rewrap(method_name) -> None: +async def test_async_methods_rewrap(method_name: str) -> None: method, async_method = method_pair(".", method_name) result = method() diff --git a/trio/_tests/test_scheduler_determinism.py b/trio/_tests/test_scheduler_determinism.py index 1c438f136c..4c0da698ce 100644 --- a/trio/_tests/test_scheduler_determinism.py +++ b/trio/_tests/test_scheduler_determinism.py @@ -5,7 +5,7 @@ async def scheduler_trace(): """Returns a scheduler-dependent value we can use to check determinism.""" trace = [] - async def tracer(name) -> None: + async def tracer(name: str) -> None: for i in range(50): trace.append((name, i)) await trio.sleep(0) diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index 06ac2e694c..8af210a0ff 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -88,7 +88,7 @@ def client_ctx(request): # The blocking socket server. -def ssl_echo_serve_sync(sock, *, expect_fail=False): +def ssl_echo_serve_sync(sock, *, expect_fail: bool = False): try: wrapped = SERVER_CTX.wrap_socket( sock, server_side=True, suppress_ragged_eofs=False diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index 2d805a1814..0f993976e0 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -241,10 +241,10 @@ async def test_interactive(background_process) -> None: ) as proc: newline = b"\n" if posix else b"\r\n" - async def expect(idx, request) -> None: + async def expect(idx: int, request) -> None: async with _core.open_nursery() as nursery: - async def drain_one(stream, count, digit) -> None: + async def drain_one(stream, count: int, digit) -> None: while count > 0: result = await stream.receive_some(count) assert result == (f"{digit}".encode() * len(result)) diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py index 4325e33a6c..7747740c9f 100644 --- a/trio/_tests/test_sync.py +++ b/trio/_tests/test_sync.py @@ -527,7 +527,7 @@ async def test_generic_lock_fifo_fairness(lock_factory) -> None: record = [] LOOPS = 5 - async def loopy(name, lock_like) -> None: + async def loopy(name: str, lock_like) -> None: # Record the order each task was initially scheduled in initial_order.append(name) for _ in range(LOOPS): From 6cad1937c4c78811cf38755f175e53e2155e2abc Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 17 Oct 2023 12:54:46 +0200 Subject: [PATCH 04/35] WIP before merging origin/master --- pyproject.toml | 25 ++--- trio/_core/_tests/test_guest_mode.py | 98 ++++++++++--------- trio/_core/_tests/test_instrumentation.py | 10 +- trio/_core/_tests/test_ki.py | 92 ++++++++--------- trio/_core/_tests/test_mock_clock.py | 22 +++-- trio/_core/_tests/test_multierror.py | 37 +++---- trio/_core/_tests/test_parking_lot.py | 34 ++++--- trio/_core/_tests/test_run.py | 4 +- trio/_core/_tests/test_thread_cache.py | 36 +++---- trio/_core/_tests/tutil.py | 8 +- trio/_tests/test_file_io.py | 2 +- trio/_tests/test_highlevel_generic.py | 10 +- trio/_tests/test_highlevel_serve_listeners.py | 12 ++- trio/_tests/test_highlevel_ssl_helpers.py | 14 ++- trio/_tests/test_path.py | 50 +++++----- trio/_tests/test_ssl.py | 13 ++- trio/_tests/test_subprocess.py | 8 +- trio/_tests/test_sync.py | 16 +-- trio/_tests/test_testing.py | 2 +- trio/_tests/test_threads.py | 64 +++++++----- trio/_tests/test_util.py | 43 ++++---- trio/_tests/test_windows_pipes.py | 18 ++-- 22 files changed, 332 insertions(+), 286 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e12f95c309..81e8232965 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,10 +80,14 @@ disallow_untyped_defs = true check_untyped_defs = true disallow_untyped_calls = false - -# partially typed tests +# files not yet fully typed [[tool.mypy.overrides]] module = [ +# internal +"trio/_windows_pipes", + +# tests +"trio/testing/_fake_net", "trio/_core/_tests/test_guest_mode", "trio/_core/_tests/test_ki", "trio/_core/_tests/test_mock_clock", @@ -92,22 +96,7 @@ module = [ "trio/_core/_tests/test_multierror_scripts/simple_excepthook", "trio/_core/_tests/test_parking_lot", "trio/_core/_tests/test_thread_cache", -] -check_untyped_defs = true -disallow_any_decorated = false -disallow_any_generics = false -disallow_any_unimported = false -disallow_incomplete_defs = false -disallow_untyped_defs = false - -# files not yet fully typed -[[tool.mypy.overrides]] -module = [ -# internal -"trio/_windows_pipes", - -# tests -"trio/testing/_fake_net", +"trio/_core/_tests/test_unbounded_queue", "trio/_tests/test_exports", "trio/_tests/test_file_io", "trio/_tests/test_highlevel_generic", diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index 21980b9b72..3789cdefa3 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -35,7 +35,7 @@ def trivial_guest_run(trio_fn, *, in_host_after_start=None, **start_guest_run_kw host_thread = threading.current_thread() - def run_sync_soon_threadsafe(fn: Callable): + def run_sync_soon_threadsafe(fn: Callable) -> None: nonlocal todo if host_thread is threading.current_thread(): # pragma: no cover crash = partial( @@ -44,7 +44,7 @@ def run_sync_soon_threadsafe(fn: Callable): todo.put(("run", crash)) todo.put(("run", fn)) - def run_sync_soon_not_threadsafe(fn: Callable): + def run_sync_soon_not_threadsafe(fn: Callable) -> None: nonlocal todo if host_thread is not threading.current_thread(): # pragma: no cover crash = partial( @@ -53,7 +53,7 @@ def run_sync_soon_not_threadsafe(fn: Callable): todo.put(("run", crash)) todo.put(("run", fn)) - def done_callback(outcome: Outcome): + def done_callback(outcome: Outcome) -> None: nonlocal todo todo.put(("unwrap", outcome)) @@ -84,8 +84,8 @@ def done_callback(outcome: Outcome): del todo, run_sync_soon_threadsafe, done_callback -def test_guest_trivial(): - async def trio_return(in_host): +def test_guest_trivial() -> None: + async def trio_return(in_host) -> str: await trio.sleep(0) return "ok" @@ -98,14 +98,14 @@ async def trio_fail(in_host): trivial_guest_run(trio_fail) -def test_guest_can_do_io(): - async def trio_main(in_host): +def test_guest_can_do_io() -> None: + async def trio_main(in_host) -> None: record = [] a, b = trio.socket.socketpair() with a, b: async with trio.open_nursery() as nursery: - async def do_receive(): + async def do_receive() -> None: record.append(await a.recv(1)) nursery.start_soon(do_receive) @@ -118,17 +118,17 @@ async def do_receive(): trivial_guest_run(trio_main) -def test_guest_is_initialized_when_start_returns(): +def test_guest_is_initialized_when_start_returns() -> None: trio_token = None record = [] - async def trio_main(in_host): + async def trio_main(in_host) -> str: record.append("main task ran") await trio.sleep(0) assert trio.lowlevel.current_trio_token() is trio_token return "ok" - def after_start(): + def after_start() -> None: # We should get control back before the main task executes any code assert record == [] @@ -137,7 +137,7 @@ def after_start(): trio_token.run_sync_soon(record.append, "run_sync_soon cb ran") @trio.lowlevel.spawn_system_task - async def early_task(): + async def early_task() -> None: record.append("system task ran") await trio.sleep(0) @@ -153,7 +153,7 @@ class BadClock: def start_clock(self): raise ValueError("whoops") - def after_start_never_runs(): # pragma: no cover + def after_start_never_runs() -> None: # pragma: no cover pytest.fail("shouldn't get here") trivial_guest_run( @@ -161,8 +161,8 @@ def after_start_never_runs(): # pragma: no cover ) -def test_host_can_directly_wake_trio_task(): - async def trio_main(in_host): +def test_host_can_directly_wake_trio_task() -> None: + async def trio_main(in_host) -> str: ev = trio.Event() in_host(ev.set) await ev.wait() @@ -171,11 +171,11 @@ async def trio_main(in_host): assert trivial_guest_run(trio_main) == "ok" -def test_host_altering_deadlines_wakes_trio_up(): - def set_deadline(cscope, new_deadline): +def test_host_altering_deadlines_wakes_trio_up() -> None: + def set_deadline(cscope, new_deadline) -> None: cscope.deadline = new_deadline - async def trio_main(in_host): + async def trio_main(in_host) -> str: with trio.CancelScope() as cscope: in_host(lambda: set_deadline(cscope, -inf)) await trio.sleep_forever() @@ -194,11 +194,11 @@ async def trio_main(in_host): assert trivial_guest_run(trio_main) == "ok" -def test_guest_mode_sniffio_integration(): +def test_guest_mode_sniffio_integration() -> None: from sniffio import current_async_library, thread_local as sniffio_library - async def trio_main(in_host): - async def synchronize(): + async def trio_main(in_host) -> str: + async def synchronize() -> None: """Wait for all in_host() calls issued so far to complete.""" evt = trio.Event() in_host(evt.set) @@ -223,10 +223,10 @@ async def synchronize(): sniffio_library.name = None -def test_warn_set_wakeup_fd_overwrite(): +def test_warn_set_wakeup_fd_overwrite() -> None: assert signal.set_wakeup_fd(-1) == -1 - async def trio_main(in_host): + async def trio_main(in_host) -> str: return "ok" a, b = socket.socketpair() @@ -268,7 +268,7 @@ async def trio_main(in_host): signal.set_wakeup_fd(a.fileno()) try: - async def trio_check_wakeup_fd_unaltered(in_host): + async def trio_check_wakeup_fd_unaltered(in_host) -> str: fd = signal.set_wakeup_fd(-1) assert fd == a.fileno() signal.set_wakeup_fd(fd) @@ -287,19 +287,19 @@ async def trio_check_wakeup_fd_unaltered(in_host): assert signal.set_wakeup_fd(-1) == a.fileno() -def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked(): +def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked() -> None: # This is designed to hit the branch in unrolled_run where: # idle_primed=True # runner.runq is empty # events is Truth-y # ...and confirm that in this case, wait_all_tasks_blocked does not get # triggered. - def set_deadline(cscope, new_deadline): + def set_deadline(cscope, new_deadline) -> None: print(f"setting deadline {new_deadline}") cscope.deadline = new_deadline - async def trio_main(in_host): - async def sit_in_wait_all_tasks_blocked(watb_cscope): + async def trio_main(in_host) -> str: + async def sit_in_wait_all_tasks_blocked(watb_cscope) -> None: with watb_cscope: # Overall point of this test is that this # wait_all_tasks_blocked should *not* return normally, but @@ -308,7 +308,7 @@ async def sit_in_wait_all_tasks_blocked(watb_cscope): assert False # pragma: no cover assert watb_cscope.cancelled_caught - async def get_woken_by_host_deadline(watb_cscope): + async def get_woken_by_host_deadline(watb_cscope) -> None: with trio.CancelScope() as cscope: print("scheduling stuff to happen") @@ -360,13 +360,13 @@ def after_io_wait(self, timeout: float) -> None: @restore_unraisablehook() -def test_guest_warns_if_abandoned(): +def test_guest_warns_if_abandoned() -> None: # This warning is emitted from the garbage collector. So we have to make # sure that our abandoned run is garbage. The easiest way to do this is to # put it into a function, so that we're sure all the local state, # traceback frames, etc. are garbage once it returns. - def do_abandoned_guest_run(): - async def abandoned_main(in_host): + def do_abandoned_guest_run() -> None: + async def abandoned_main(in_host) -> None: in_host(lambda: 1 / 0) while True: await trio.sleep(0) @@ -401,13 +401,13 @@ async def abandoned_main(in_host): trio.current_time() -def aiotrio_run(trio_fn, *, pass_not_threadsafe=True, **start_guest_run_kwargs): +def aiotrio_run(trio_fn, *, pass_not_threadsafe: bool = True, **start_guest_run_kwargs): loop = asyncio.new_event_loop() async def aio_main(): trio_done_fut = loop.create_future() - def trio_done_callback(main_outcome): + def trio_done_callback(main_outcome) -> None: print(f"trio_fn finished: {main_outcome!r}") trio_done_fut.set_result(main_outcome) @@ -429,8 +429,8 @@ def trio_done_callback(main_outcome): loop.close() -def test_guest_mode_on_asyncio(): - async def trio_main(): +def test_guest_mode_on_asyncio() -> None: + async def trio_main() -> str: print("trio_main!") to_trio, from_aio = trio.open_memory_channel[int](float("inf")) @@ -451,6 +451,8 @@ async def trio_main(): aio_task.cancel() return "trio-main-done" + raise AssertionError("should never be reached") + async def aio_pingpong(from_trio, to_trio): print("aio_pingpong!") @@ -488,10 +490,10 @@ async def aio_pingpong(from_trio, to_trio): ) -def test_guest_mode_internal_errors(monkeypatch, recwarn): +def test_guest_mode_internal_errors(monkeypatch, recwarn) -> None: with monkeypatch.context() as m: - async def crash_in_run_loop(in_host): + async def crash_in_run_loop(in_host) -> None: m.setattr("trio._core._run.GLOBAL_RUN_CONTEXT.runner.runq", "HI") await trio.sleep(1) @@ -500,7 +502,7 @@ async def crash_in_run_loop(in_host): with monkeypatch.context() as m: - async def crash_in_io(in_host): + async def crash_in_io(in_host) -> None: m.setattr("trio._core._run.TheIOManager.get_events", None) await trio.sleep(0) @@ -509,7 +511,7 @@ async def crash_in_io(in_host): with monkeypatch.context() as m: - async def crash_in_worker_thread_io(in_host): + async def crash_in_worker_thread_io(in_host) -> None: t = threading.current_thread() old_get_events = trio._core._run.TheIOManager.get_events @@ -529,11 +531,11 @@ def bad_get_events(*args): gc_collect_harder() -def test_guest_mode_ki(): +def test_guest_mode_ki() -> None: assert signal.getsignal(signal.SIGINT) is signal.default_int_handler # Check SIGINT in Trio func and in host func - async def trio_main(in_host): + async def trio_main(in_host) -> None: with pytest.raises(KeyboardInterrupt): signal_raise(signal.SIGINT) @@ -561,7 +563,7 @@ async def trio_main_raising(in_host): assert signal.getsignal(signal.SIGINT) is signal.default_int_handler -def test_guest_mode_autojump_clock_threshold_changing(): +def test_guest_mode_autojump_clock_threshold_changing() -> None: # This is super obscure and probably no-one will ever notice, but # technically mutating the MockClock.autojump_threshold from the host # should wake up the guest, so let's test it. @@ -570,7 +572,7 @@ def test_guest_mode_autojump_clock_threshold_changing(): DURATION = 120 - async def trio_main(in_host): + async def trio_main(in_host) -> None: assert trio.current_time() == 0 in_host(lambda: setattr(clock, "autojump_threshold", 0)) await trio.sleep(DURATION) @@ -586,12 +588,12 @@ async def trio_main(in_host): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy") @restore_unraisablehook() -def test_guest_mode_asyncgens(): +def test_guest_mode_asyncgens() -> None: import sniffio record = set() - async def agen(label): + async def agen(label: str): assert sniffio.current_async_library() == label try: yield 1 @@ -603,10 +605,10 @@ async def agen(label): pass record.add((label, library)) - async def iterate_in_aio(): + async def iterate_in_aio() -> None: await agen("asyncio").asend(None) - async def trio_main(): + async def trio_main() -> None: task = asyncio.ensure_future(iterate_in_aio()) done_evt = trio.Event() task.add_done_callback(lambda _: done_evt.set()) diff --git a/trio/_core/_tests/test_instrumentation.py b/trio/_core/_tests/test_instrumentation.py index 8b103dbfdc..ecd7585ef2 100644 --- a/trio/_core/_tests/test_instrumentation.py +++ b/trio/_core/_tests/test_instrumentation.py @@ -75,9 +75,9 @@ async def main() -> None: # reschedules the task immediately upon yielding, before the # after_task_step event fires. expected = ( - [("before_run",), ("schedule", task)] + [("before_run", None), ("schedule", task)] + [("before", task), ("schedule", task), ("after", task)] * 5 - + [("before", task), ("after", task), ("after_run",)] + + [("before", task), ("after", task), ("after_run", None)] ) assert r1.record == r2.record + r3.record assert task is not None @@ -104,7 +104,7 @@ async def main() -> None: _core.run(main, instruments=[r]) expected = [ - ("before_run",), + ("before_run", None), ("schedule", tasks["t1"]), ("schedule", tasks["t2"]), { @@ -121,7 +121,7 @@ async def main() -> None: ("before", tasks["t2"]), ("after", tasks["t2"]), }, - ("after_run",), + ("after_run", None), ] print(list(r.filter_tasks(tasks.values()))) check_sequence_matches(list(r.filter_tasks(tasks.values())), expected) @@ -199,7 +199,7 @@ async def main() -> Task: # the TaskRecorder kept going throughout, even though the BrokenInstrument # was disabled assert ("after", main_task) in r.record - assert ("after_run",) in r.record + assert ("after_run", None) in r.record # And we got a log message assert caplog.records[0].exc_info is not None exc_type, exc_value, exc_traceback = caplog.records[0].exc_info diff --git a/trio/_core/_tests/test_ki.py b/trio/_core/_tests/test_ki.py index 0fda688194..60a6b17336 100644 --- a/trio/_core/_tests/test_ki.py +++ b/trio/_core/_tests/test_ki.py @@ -24,16 +24,16 @@ from ..._core import Abort, RaiseCancelT -def ki_self(): +def ki_self() -> None: signal_raise(signal.SIGINT) -def test_ki_self(): +def test_ki_self() -> None: with pytest.raises(KeyboardInterrupt): ki_self() -async def test_ki_enabled(): +async def test_ki_enabled() -> None: # Regular tasks aren't KI-protected assert not _core.currently_ki_protected() @@ -41,7 +41,7 @@ async def test_ki_enabled(): token = _core.current_trio_token() record = [] - def check(): + def check() -> None: record.append(_core.currently_ki_protected()) token.run_sync_soon(check) @@ -49,23 +49,23 @@ def check(): assert record == [True] @_core.enable_ki_protection - def protected(): + def protected() -> None: assert _core.currently_ki_protected() unprotected() @_core.disable_ki_protection - def unprotected(): + def unprotected() -> None: assert not _core.currently_ki_protected() protected() @_core.enable_ki_protection - async def aprotected(): + async def aprotected() -> None: assert _core.currently_ki_protected() await aunprotected() @_core.disable_ki_protection - async def aunprotected(): + async def aunprotected() -> None: assert not _core.currently_ki_protected() await aprotected() @@ -102,16 +102,16 @@ def gen_unprotected(): # .throw(), not the actual caller. So child() here would have a caller deep in # the guts of the run loop, and always be protected, even when it shouldn't # have been. (Solution: we don't use .throw() anymore.) -async def test_ki_enabled_after_yield_briefly(): +async def test_ki_enabled_after_yield_briefly() -> None: @_core.enable_ki_protection - async def protected(): + async def protected() -> None: await child(True) @_core.disable_ki_protection - async def unprotected(): + async def unprotected() -> None: await child(False) - async def child(expected): + async def child(expected) -> None: import traceback traceback.print_stack() @@ -146,10 +146,10 @@ def protected_manager(): @pytest.mark.skipif(async_generator is None, reason="async_generator not installed") -async def test_async_generator_agen_protection(): +async def test_async_generator_agen_protection() -> None: @_core.enable_ki_protection - @async_generator - async def agen_protected1(): + @async_generator # type: ignore[misc] # untyped generator + async def agen_protected1() -> None: assert _core.currently_ki_protected() try: await yield_() @@ -157,8 +157,8 @@ async def agen_protected1(): assert _core.currently_ki_protected() @_core.disable_ki_protection - @async_generator - async def agen_unprotected1(): + @async_generator # type: ignore[misc] # untyped generator + async def agen_unprotected1() -> None: assert not _core.currently_ki_protected() try: await yield_() @@ -166,18 +166,18 @@ async def agen_unprotected1(): assert not _core.currently_ki_protected() # Swap the order of the decorators: - @async_generator + @async_generator # type: ignore[misc] # untyped generator @_core.enable_ki_protection - async def agen_protected2(): + async def agen_protected2() -> None: assert _core.currently_ki_protected() try: await yield_() finally: assert _core.currently_ki_protected() - @async_generator + @async_generator # type: ignore[misc] # untyped generator @_core.disable_ki_protection - async def agen_unprotected2(): + async def agen_unprotected2() -> None: assert not _core.currently_ki_protected() try: await yield_() @@ -190,7 +190,7 @@ async def agen_unprotected2(): await _check_agen(agen_unprotected2) -async def test_native_agen_protection(): +async def test_native_agen_protection() -> None: # Native async generators @_core.enable_ki_protection async def agen_protected(): @@ -230,20 +230,20 @@ async def _check_agen(agen_fn): # Test the case where there's no magic local anywhere in the call stack -def test_ki_disabled_out_of_context(): +def test_ki_disabled_out_of_context() -> None: assert _core.currently_ki_protected() -def test_ki_disabled_in_del(): +def test_ki_disabled_in_del() -> None: def nestedfunction(): return _core.currently_ki_protected() - def __del__(): + def __del__() -> None: assert _core.currently_ki_protected() assert nestedfunction() @_core.disable_ki_protection - def outerfunction(): + def outerfunction() -> None: assert not _core.currently_ki_protected() assert not nestedfunction() __del__() @@ -253,15 +253,15 @@ def outerfunction(): assert nestedfunction() -def test_ki_protection_works(): - async def sleeper(name, record): +def test_ki_protection_works() -> None: + async def sleeper(name: str, record) -> None: try: while True: await _core.checkpoint() except _core.Cancelled: record.add(name + " ok") - async def raiser(name, record): + async def raiser(name: str, record): try: # os.kill runs signal handlers before returning, so we don't need # to worry that the handler will be delayed @@ -286,7 +286,7 @@ async def raiser(name, record): print("check 1") record_set: set[str] = set() - async def check_unprotected_kill(): + async def check_unprotected_kill() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(sleeper, "s1", record_set) nursery.start_soon(sleeper, "s2", record_set) @@ -301,7 +301,7 @@ async def check_unprotected_kill(): print("check 2") record_set = set() - async def check_protected_kill(): + async def check_protected_kill() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(sleeper, "s1", record_set) nursery.start_soon(sleeper, "s2", record_set) @@ -316,10 +316,10 @@ async def check_protected_kill(): # error, then kill) print("check 3") - async def check_kill_during_shutdown(): + async def check_kill_during_shutdown() -> None: token = _core.current_trio_token() - def kill_during_shutdown(): + def kill_during_shutdown() -> None: assert _core.currently_ki_protected() try: token.run_sync_soon(kill_during_shutdown) @@ -340,7 +340,7 @@ class InstrumentOfDeath(Instrument): def before_run(self) -> None: ki_self() - async def main_1(): + async def main_1() -> None: await _core.checkpoint() with pytest.raises(KeyboardInterrupt): @@ -350,7 +350,7 @@ async def main_1(): print("check 5") @_core.enable_ki_protection - async def main_2(): + async def main_2() -> None: assert _core.currently_ki_protected() ki_self() with pytest.raises(KeyboardInterrupt): @@ -362,7 +362,7 @@ async def main_2(): print("check 6") @_core.enable_ki_protection - async def main_3(): + async def main_3() -> None: assert _core.currently_ki_protected() ki_self() await _core.cancel_shielded_checkpoint() @@ -377,7 +377,7 @@ async def main_3(): print("check 7") @_core.enable_ki_protection - async def main_4(): + async def main_4() -> None: assert _core.currently_ki_protected() ki_self() task = _core.current_task() @@ -396,7 +396,7 @@ def abort(_: RaiseCancelT) -> Abort: print("check 8") @_core.enable_ki_protection - async def main_5(): + async def main_5() -> None: assert _core.currently_ki_protected() ki_self() task = _core.current_task() @@ -419,7 +419,7 @@ def abort(raise_cancel: RaiseCancelT) -> Abort: print("check 9") @_core.enable_ki_protection - async def main_6(): + async def main_6() -> None: ki_self() with pytest.raises(KeyboardInterrupt): @@ -430,7 +430,7 @@ async def main_6(): # restrict_keyboard_interrupt_to_checkpoints=True record_list = [] - async def main_7(): + async def main_7() -> None: # We're not KI protected... assert not _core.currently_ki_protected() ki_self() @@ -454,7 +454,7 @@ async def main_7(): print("check 11") @_core.enable_ki_protection - async def main_8(): + async def main_8() -> None: assert _core.currently_ki_protected() with _core.CancelScope() as cancel_scope: cancel_scope.cancel() @@ -469,16 +469,16 @@ async def main_8(): _core.run(main_8) -def test_ki_is_good_neighbor(): +def test_ki_is_good_neighbor() -> None: # in the unlikely event someone overwrites our signal handler, we leave # the overwritten one be try: orig = signal.getsignal(signal.SIGINT) - def my_handler(signum, frame): # pragma: no cover + def my_handler(signum, frame) -> None: # pragma: no cover pass - async def main(): + async def main() -> None: signal.signal(signal.SIGINT, my_handler) _core.run(main) @@ -490,7 +490,7 @@ async def main(): # Regression test for #461 # don't know if _active not being visible is a problem -def test_ki_with_broken_threads(): +def test_ki_with_broken_threads() -> None: thread = threading.main_thread() # scary! @@ -502,7 +502,7 @@ def test_ki_with_broken_threads(): del threading._active[thread.ident] # type: ignore[attr-defined] @_core.enable_ki_protection - async def inner(): + async def inner() -> None: assert signal.getsignal(signal.SIGINT) != signal.default_int_handler _core.run(inner) diff --git a/trio/_core/_tests/test_mock_clock.py b/trio/_core/_tests/test_mock_clock.py index 9c74df3334..4d655cd1a4 100644 --- a/trio/_core/_tests/test_mock_clock.py +++ b/trio/_core/_tests/test_mock_clock.py @@ -11,7 +11,7 @@ from .tutil import slow -def test_mock_clock(): +def test_mock_clock() -> None: REAL_NOW = 123.0 c = MockClock() c._real_clock = lambda: REAL_NOW @@ -55,7 +55,7 @@ def test_mock_clock(): assert c2.current_time() < 10 -async def test_mock_clock_autojump(mock_clock): +async def test_mock_clock_autojump(mock_clock) -> None: assert mock_clock.autojump_threshold == inf mock_clock.autojump_threshold = 0 @@ -95,7 +95,7 @@ async def test_mock_clock_autojump(mock_clock): await sleep(100000) -async def test_mock_clock_autojump_interference(mock_clock): +async def test_mock_clock_autojump_interference(mock_clock) -> None: mock_clock.autojump_threshold = 0.02 mock_clock2 = MockClock() @@ -112,7 +112,7 @@ async def test_mock_clock_autojump_interference(mock_clock): await sleep(100000) -def test_mock_clock_autojump_preset(): +def test_mock_clock_autojump_preset() -> None: # Check that we can set the autojump_threshold before the clock is # actually in use, and it gets picked up mock_clock = MockClock(autojump_threshold=0.1) @@ -122,7 +122,7 @@ def test_mock_clock_autojump_preset(): assert time.perf_counter() - real_start < 1 -async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock): +async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock) -> None: # Checks that autojump_threshold=0 doesn't interfere with # calling wait_all_tasks_blocked with the default cushion=0. @@ -130,11 +130,11 @@ async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock): record = [] - async def sleeper(): + async def sleeper() -> None: await sleep(100) record.append("yawn") - async def waiter(): + async def waiter() -> None: await wait_all_tasks_blocked() record.append("waiter woke") await sleep(1000) @@ -148,7 +148,9 @@ async def waiter(): @slow -async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(mock_clock): +async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero( + mock_clock, +) -> None: # Checks that autojump_threshold=0 doesn't interfere with # calling wait_all_tasks_blocked with a non-zero cushion. @@ -156,11 +158,11 @@ async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(mock_clo record = [] - async def sleeper(): + async def sleeper() -> None: await sleep(100) record.append("yawn") - async def waiter(): + async def waiter() -> None: await wait_all_tasks_blocked(1) record.append("waiter done") diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py index 6d9fd2a568..60ac41ae14 100644 --- a/trio/_core/_tests/test_multierror.py +++ b/trio/_core/_tests/test_multierror.py @@ -9,6 +9,7 @@ import warnings from pathlib import Path from traceback import extract_tb, print_exception +from typing import List import pytest @@ -38,11 +39,11 @@ async def raise_nothashable(code): raise NotHashableException(code) -def raiser1(): +def raiser1() -> None: raiser1_2() -def raiser1_2(): +def raiser1_2() -> None: raiser1_3() @@ -50,7 +51,7 @@ def raiser1_3(): raise ValueError("raiser1_string") -def raiser2(): +def raiser2() -> None: raiser2_2() @@ -73,7 +74,7 @@ def get_tb(raiser): return get_exc(raiser).__traceback__ -def test_concat_tb(): +def test_concat_tb() -> None: tb1 = get_tb(raiser1) tb2 = get_tb(raiser2) @@ -98,7 +99,7 @@ def test_concat_tb(): assert extract_tb(get_tb(raiser2)) == entries2 -def test_MultiError(): +def test_MultiError() -> None: exc1 = get_exc(raiser1) exc2 = get_exc(raiser2) @@ -109,12 +110,12 @@ def test_MultiError(): assert "ValueError" in repr(m) with pytest.raises(TypeError): - MultiError(object()) + MultiError(object()) # type: ignore[arg-type] with pytest.raises(TypeError): - MultiError([KeyError(), ValueError]) + MultiError([KeyError(), ValueError]) # type: ignore[list-item] -def test_MultiErrorOfSingleMultiError(): +def test_MultiErrorOfSingleMultiError() -> None: # For MultiError([MultiError]), ensure there is no bad recursion by the # constructor where __init__ is called if __new__ returns a bare MultiError. exceptions = (KeyError(), ValueError()) @@ -124,7 +125,7 @@ def test_MultiErrorOfSingleMultiError(): assert b.exceptions == exceptions -async def test_MultiErrorNotHashable(): +async def test_MultiErrorNotHashable() -> None: exc1 = NotHashableException(42) exc2 = NotHashableException(4242) exc3 = ValueError() @@ -137,7 +138,7 @@ async def test_MultiErrorNotHashable(): nursery.start_soon(raise_nothashable, 4242) -def test_MultiError_filter_NotHashable(): +def test_MultiError_filter_NotHashable() -> None: excs = MultiError([NotHashableException(42), ValueError()]) def handle_ValueError(exc): @@ -173,7 +174,7 @@ def make_tree(): return MultiError([m12, exc3]) -def assert_tree_eq(m1, m2): +def assert_tree_eq(m1, m2) -> None: if m1 is None or m2 is None: assert m1 is m2 return @@ -240,7 +241,7 @@ def simple_filter(exc): + extract_tb(orig.exceptions[0].exceptions[1].__traceback__) ) - def p(exc): + def p(exc) -> None: print_exception(type(exc), exc, exc.__traceback__) p(orig) @@ -275,7 +276,7 @@ def filter_all(exc): def test_MultiError_catch(): # No exception to catch - def noop(_): + def noop(_) -> None: pass # pragma: no cover with pytest.warns(TrioDeprecationWarning), MultiError.catch(noop): @@ -391,7 +392,7 @@ def simple_filter(exc): gc.garbage.clear() -def assert_match_in_seq(pattern_list, string): +def assert_match_in_seq(pattern_list: List[str], string: str) -> None: offset = 0 print("looking for pattern matches...") for pattern in pattern_list: @@ -402,14 +403,14 @@ def assert_match_in_seq(pattern_list, string): offset = match.end() -def test_assert_match_in_seq(): +def test_assert_match_in_seq() -> None: assert_match_in_seq(["a", "b"], "xx a xx b xx") assert_match_in_seq(["b", "a"], "xx b xx a xx") with pytest.raises(AssertionError): assert_match_in_seq(["a", "b"], "xx b xx a xx") -def test_base_multierror(): +def test_base_multierror() -> None: """ Test that MultiError() with at least one base exception will return a MultiError object. @@ -419,7 +420,7 @@ def test_base_multierror(): assert type(exc) is MultiError -def test_non_base_multierror(): +def test_non_base_multierror() -> None: """ Test that MultiError() without base exceptions will return a NonBaseMultiError object. @@ -462,7 +463,7 @@ def run_script(name: str) -> subprocess.CompletedProcess[bytes]: not Path("/usr/lib/python3/dist-packages/apport_python_hook.py").exists(), reason="need Ubuntu with python3-apport installed", ) -def test_apport_excepthook_monkeypatch_interaction(): +def test_apport_excepthook_monkeypatch_interaction() -> None: completed = run_script("apport_excepthook.py") stdout = completed.stdout.decode("utf-8") diff --git a/trio/_core/_tests/test_parking_lot.py b/trio/_core/_tests/test_parking_lot.py index 3f03fdbade..2946a2fb3c 100644 --- a/trio/_core/_tests/test_parking_lot.py +++ b/trio/_core/_tests/test_parking_lot.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import TypeVar + import pytest from ... import _core @@ -5,11 +9,13 @@ from .._parking_lot import ParkingLot from .tutil import check_sequence_matches +T = TypeVar("T") + -async def test_parking_lot_basic(): +async def test_parking_lot_basic() -> None: record = [] - async def waiter(i, lot): + async def waiter(i, lot) -> None: record.append(f"sleep {i}") await lot.park() record.append(f"wake {i}") @@ -76,7 +82,9 @@ async def waiter(i, lot): lot.unpark(count=1.5) -async def cancellable_waiter(name, lot, scopes, record): +async def cancellable_waiter( + name: T, lot, scopes: dict[T, _core.CancelScope], record: list[str] +) -> None: with _core.CancelScope() as scope: scopes[name] = scope record.append(f"sleep {name}") @@ -88,9 +96,9 @@ async def cancellable_waiter(name, lot, scopes, record): record.append(f"wake {name}") -async def test_parking_lot_cancel(): - record = [] - scopes = {} +async def test_parking_lot_cancel() -> None: + record: list[str] = [] + scopes: dict[int, _core.CancelScope] = {} async with _core.open_nursery() as nursery: lot = ParkingLot() @@ -114,14 +122,14 @@ async def test_parking_lot_cancel(): ) -async def test_parking_lot_repark(): - record = [] - scopes = {} +async def test_parking_lot_repark() -> None: + record: list[str] = [] + scopes: dict[int, _core.CancelScope] = {} lot1 = ParkingLot() lot2 = ParkingLot() with pytest.raises(TypeError): - lot1.repark([]) + lot1.repark([]) # type: ignore[arg-type] async with _core.open_nursery() as nursery: nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record) @@ -168,9 +176,9 @@ async def test_parking_lot_repark(): ] -async def test_parking_lot_repark_with_count(): - record = [] - scopes = {} +async def test_parking_lot_repark_with_count() -> None: + record: list[str] = [] + scopes: dict[int, _core.CancelScope] = {} lot1 = ParkingLot() lot2 = ParkingLot() async with _core.open_nursery() as nursery: diff --git a/trio/_core/_tests/test_run.py b/trio/_core/_tests/test_run.py index f67f83a4b8..8af166b9af 100644 --- a/trio/_core/_tests/test_run.py +++ b/trio/_core/_tests/test_run.py @@ -1878,7 +1878,7 @@ async def fail() -> NoReturn: async def test_nursery_stop_async_iteration() -> None: class it: - def __init__(self, count: int): + def __init__(self, count: int) -> None: self.count = count self.val = 0 @@ -1891,7 +1891,7 @@ async def __anext__(self) -> int: return val class async_zip: - def __init__(self, *largs: it): + def __init__(self, *largs: it) -> None: self.nexts = [obj.__anext__ for obj in largs] async def _accumulate( diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py index 3cd79ecd8a..4893ac3cae 100644 --- a/trio/_core/_tests/test_thread_cache.py +++ b/trio/_core/_tests/test_thread_cache.py @@ -14,13 +14,13 @@ from .tutil import gc_collect_harder, slow -def test_thread_cache_basics(): +def test_thread_cache_basics() -> None: q = Queue[Outcome]() def fn() -> NoReturn: raise RuntimeError("hi") - def deliver(outcome): + def deliver(outcome) -> None: q.put(outcome) start_thread_soon(fn, deliver) @@ -30,19 +30,19 @@ def deliver(outcome): outcome.unwrap() -def test_thread_cache_deref(): +def test_thread_cache_deref() -> None: res = [False] class del_me: - def __call__(self): + def __call__(self) -> int: return 42 - def __del__(self): + def __del__(self) -> None: res[0] = True - q = Queue() + q = Queue[Outcome]() - def deliver(outcome): + def deliver(outcome) -> None: q.put(outcome) start_thread_soon(del_me(), deliver) @@ -54,7 +54,7 @@ def deliver(outcome): @slow -def test_spawning_new_thread_from_deliver_reuses_starting_thread(): +def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None: # We know that no-one else is using the thread cache, so if we keep # submitting new jobs the instant the previous one is finished, we should # keep getting the same thread over and over. This tests both that the @@ -63,7 +63,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread(): # Make sure there are a few threads running, so if we weren't LIFO then we # could grab the wrong one. - q = Queue() + q = Queue[Outcome]() COUNT = 5 for _ in range(COUNT): start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result)) @@ -73,7 +73,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread(): seen_threads = set() done = threading.Event() - def deliver(n, _): + def deliver(n, _) -> None: print(n) seen_threads.add(threading.current_thread()) if n == 0: @@ -89,13 +89,13 @@ def deliver(n, _): @slow -def test_idle_threads_exit(monkeypatch): +def test_idle_threads_exit(monkeypatch) -> None: # Temporarily set the idle timeout to something tiny, to speed up the # test. (But non-zero, so that the worker loop will at least yield the # CPU.) monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001) - q = Queue() + q = Queue[threading.Thread]() start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread())) seen_thread = q.get() # Since the idle timeout is 0, after sleeping for 1 second, the thread @@ -116,7 +116,7 @@ def _join_started_threads(): assert not thread.is_alive() -def test_race_between_idle_exit_and_job_assignment(monkeypatch): +def test_race_between_idle_exit_and_job_assignment(monkeypatch) -> None: # This is a lock where the first few times you try to acquire it with a # timeout, it waits until the lock is available and then pretends to time # out. Using this in our thread cache implementation causes the following @@ -135,11 +135,11 @@ def test_race_between_idle_exit_and_job_assignment(monkeypatch): # everything proceeds as normal. class JankyLock: - def __init__(self): + def __init__(self) -> None: self._lock = threading.Lock() self._counter = 3 - def acquire(self, timeout=-1): + def acquire(self, timeout=-1) -> bool: got_it = self._lock.acquire(timeout=timeout) if timeout == -1: return True @@ -152,7 +152,7 @@ def acquire(self, timeout=-1): else: return False - def release(self): + def release(self) -> None: self._lock.release() monkeypatch.setattr(_thread_cache, "Lock", JankyLock) @@ -169,10 +169,10 @@ def release(self): tc.start_thread_soon(lambda: None, lambda _: None) -def test_raise_in_deliver(capfd): +def test_raise_in_deliver(capfd) -> None: seen_threads = set() - def track_threads(): + def track_threads() -> None: seen_threads.add(threading.current_thread()) def deliver(_): diff --git a/trio/_core/_tests/tutil.py b/trio/_core/_tests/tutil.py index 070af8ed15..1d49d8b262 100644 --- a/trio/_core/_tests/tutil.py +++ b/trio/_core/_tests/tutil.py @@ -9,7 +9,7 @@ import warnings from collections.abc import Generator, Iterable, Sequence from contextlib import closing, contextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar import pytest @@ -18,6 +18,8 @@ slow = pytest.mark.skipif(not RUN_SLOW, reason="use --run-slow to run slow tests") +T = TypeVar("T") + # PyPy 7.2 was released with a bug that just never called the async # generator 'firstiter' hook at all. This impacts tests of end-of-run # finalization (nothing gets added to runner.asyncgens) and tests of @@ -98,9 +100,7 @@ def restore_unraisablehook() -> Generator[None, None, None]: # template is like: # [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3] -def check_sequence_matches( - seq: Sequence[object], template: Iterable[object | set[object]] -) -> None: +def check_sequence_matches(seq: Sequence[T], template: Iterable[T | set[T]]) -> None: i = 0 for pattern in template: if not isinstance(pattern, set): diff --git a/trio/_tests/test_file_io.py b/trio/_tests/test_file_io.py index 863ebe81b0..da808c5291 100644 --- a/trio/_tests/test_file_io.py +++ b/trio/_tests/test_file_io.py @@ -231,7 +231,7 @@ async def test_aclose_cancelled(path) -> None: async def test_detach_rewraps_asynciobase() -> None: raw = io.BytesIO() - buffered = io.BufferedReader(raw) + buffered = io.BufferedReader(raw) # type: ignore[arg-type] # ???????????? async_file = trio.wrap_file(buffered) diff --git a/trio/_tests/test_highlevel_generic.py b/trio/_tests/test_highlevel_generic.py index 64c5697184..5ccc69151f 100644 --- a/trio/_tests/test_highlevel_generic.py +++ b/trio/_tests/test_highlevel_generic.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import attr import pytest @@ -7,7 +9,7 @@ @attr.s class RecordSendStream(SendStream): - record = attr.ib(factory=list) + record: list[str | tuple[str, object]] = attr.ib(factory=list) async def send_all(self, data) -> None: self.record.append(("send_all", data)) @@ -21,9 +23,9 @@ async def aclose(self) -> None: @attr.s class RecordReceiveStream(ReceiveStream): - record = attr.ib(factory=list) + record: list[str | tuple[str, int | None]] = attr.ib(factory=list) - async def receive_some(self, max_bytes=None) -> None: + async def receive_some(self, max_bytes: int | None = None) -> None: # type: ignore[override] self.record.append(("receive_some", max_bytes)) async def aclose(self) -> None: @@ -53,7 +55,7 @@ async def test_StapledStream() -> None: async def fake_send_eof() -> None: send_stream.record.append("send_eof") - send_stream.send_eof = fake_send_eof + send_stream.send_eof = fake_send_eof # type: ignore[attr-defined] await stapled.send_eof() assert send_stream.record == ["send_eof"] diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py index beac59f86d..cf3c79f56f 100644 --- a/trio/_tests/test_highlevel_serve_listeners.py +++ b/trio/_tests/test_highlevel_serve_listeners.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import errno from functools import partial @@ -72,7 +74,9 @@ async def do_tests(parent_nursery) -> None: parent_nursery.cancel_scope.cancel() async with trio.open_nursery() as nursery: - l2 = await nursery.start(trio.serve_listeners, handler, listeners) + l2: list[MemoryListener] = await nursery.start( + trio.serve_listeners, handler, listeners + ) assert l2 == listeners # This is just split into another function because gh-136 isn't # implemented yet @@ -92,7 +96,7 @@ async def raise_error(): listener.accept_hook = raise_error with pytest.raises(type(error)) as excinfo: - await trio.serve_listeners(None, [listener]) + await trio.serve_listeners(None, [listener]) # type: ignore[arg-type] assert excinfo.value is error @@ -107,7 +111,7 @@ async def raise_EMFILE(): # It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900 # = 10 times total with trio.move_on_after(0.950): - await trio.serve_listeners(None, [listener]) + await trio.serve_listeners(None, [listener]) # type: ignore[arg-type] assert len(caplog.records) == 10 for record in caplog.records: @@ -133,7 +137,7 @@ async def connection_watcher(*, task_status=trio.TASK_STATUS_IGNORED): with pytest.raises(Done): async with trio.open_nursery() as nursery: - handler_nursery = await nursery.start(connection_watcher) + handler_nursery: trio.Nursery = await nursery.start(connection_watcher) await nursery.start( partial( trio.serve_listeners, diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py index 9bec4ae2f5..afb5f30a6b 100644 --- a/trio/_tests/test_highlevel_ssl_helpers.py +++ b/trio/_tests/test_highlevel_ssl_helpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from functools import partial import attr @@ -7,11 +9,13 @@ import trio.testing from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM +from .._highlevel_socket import SocketListener from .._highlevel_ssl_helpers import ( open_ssl_over_tcp_listeners, open_ssl_over_tcp_stream, serve_ssl_over_tcp, ) +from .._ssl import SSLListener # using noqa because linters don't understand how pytest fixtures work. from .test_ssl import SERVER_CTX, client_ctx # noqa: F401 @@ -48,11 +52,17 @@ async def test_open_ssl_over_tcp_stream_and_everything_else( client_ctx, # noqa: F811 # linters doesn't understand fixture ) -> None: async with trio.open_nursery() as nursery: - (listener,) = await nursery.start( + # TODO: the types are *very* funky here, this seems like an error in some signature + # unless this is doing stuff we don't want/expect end users to do + res: list[SSLListener] = await nursery.start( partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1") ) + (listener,) = res async with listener: - sockaddr = listener.transport_listener.socket.getsockname() + # listener.transport_listener is of type Listener[Stream] + tp_listener: SocketListener = listener.transport_listener # type: ignore[assignment] + + sockaddr = tp_listener.socket.getsockname() hostname_resolver = FakeHostnameResolver(sockaddr) trio.socket.set_custom_hostname_resolver(hostname_resolver) diff --git a/trio/_tests/test_path.py b/trio/_tests/test_path.py index bb24764c23..bfef1aaf2c 100644 --- a/trio/_tests/test_path.py +++ b/trio/_tests/test_path.py @@ -14,20 +14,20 @@ def path(tmpdir): return trio.Path(p) -def method_pair(path, method_name: str): +def method_pair(path, method_name): path = pathlib.Path(path) async_path = trio.Path(path) return getattr(path, method_name), getattr(async_path, method_name) -async def test_open_is_async_context_manager(path) -> None: +async def test_open_is_async_context_manager(path): async with await path.open("w") as f: assert isinstance(f, AsyncIOWrapper) assert f.closed -async def test_magic() -> None: +async def test_magic(): path = trio.Path("test") assert str(path) == "test" @@ -42,7 +42,7 @@ async def test_magic() -> None: @pytest.mark.parametrize("cls_a,cls_b", cls_pairs) -async def test_cmp_magic(cls_a, cls_b) -> None: +async def test_cmp_magic(cls_a, cls_b): a, b = cls_a(""), cls_b("") assert a == b assert not a != b @@ -69,7 +69,7 @@ async def test_cmp_magic(cls_a, cls_b) -> None: @pytest.mark.parametrize("cls_a,cls_b", cls_pairs) -async def test_div_magic(cls_a, cls_b) -> None: +async def test_div_magic(cls_a, cls_b): a, b = cls_a("a"), cls_b("b") result = a / b @@ -81,19 +81,19 @@ async def test_div_magic(cls_a, cls_b) -> None: "cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] ) @pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"]) -async def test_hash_magic(cls_a, cls_b, path) -> None: +async def test_hash_magic(cls_a, cls_b, path): a, b = cls_a(path), cls_b(path) assert hash(a) == hash(b) -async def test_forwarded_properties(path) -> None: +async def test_forwarded_properties(path): # use `name` as a representative of forwarded properties assert "name" in dir(path) assert path.name == "test" -async def test_async_method_signature(path) -> None: +async def test_async_method_signature(path): # use `resolve` as a representative of wrapped methods assert path.resolve.__name__ == "resolve" @@ -103,7 +103,7 @@ async def test_async_method_signature(path) -> None: @pytest.mark.parametrize("method_name", ["is_dir", "is_file"]) -async def test_compare_async_stat_methods(method_name: str) -> None: +async def test_compare_async_stat_methods(method_name): method, async_method = method_pair(".", method_name) result = method() @@ -112,13 +112,13 @@ async def test_compare_async_stat_methods(method_name: str) -> None: assert result == async_result -async def test_invalid_name_not_wrapped(path) -> None: +async def test_invalid_name_not_wrapped(path): with pytest.raises(AttributeError): getattr(path, "invalid_fake_attr") @pytest.mark.parametrize("method_name", ["absolute", "resolve"]) -async def test_async_methods_rewrap(method_name: str) -> None: +async def test_async_methods_rewrap(method_name): method, async_method = method_pair(".", method_name) result = method() @@ -128,7 +128,7 @@ async def test_async_methods_rewrap(method_name: str) -> None: assert str(result) == str(async_result) -async def test_forward_methods_rewrap(path, tmpdir) -> None: +async def test_forward_methods_rewrap(path, tmpdir): with_name = path.with_name("foo") with_suffix = path.with_suffix(".py") @@ -138,17 +138,17 @@ async def test_forward_methods_rewrap(path, tmpdir) -> None: assert with_suffix == tmpdir.join("test.py") -async def test_forward_properties_rewrap(path) -> None: +async def test_forward_properties_rewrap(path): assert isinstance(path.parent, trio.Path) -async def test_forward_methods_without_rewrap(path, tmpdir) -> None: +async def test_forward_methods_without_rewrap(path, tmpdir): path = await path.parent.resolve() assert path.as_uri().startswith("file:///") -async def test_repr() -> None: +async def test_repr(): path = trio.Path(".") assert repr(path) == "trio.Path('.')" @@ -164,30 +164,30 @@ class MockWrapper: _wraps = MockWrapped -async def test_type_forwards_unsupported() -> None: +async def test_type_forwards_unsupported(): with pytest.raises(TypeError): Type.generate_forwards(MockWrapper, {}) -async def test_type_wraps_unsupported() -> None: +async def test_type_wraps_unsupported(): with pytest.raises(TypeError): Type.generate_wraps(MockWrapper, {}) -async def test_type_forwards_private() -> None: +async def test_type_forwards_private(): Type.generate_forwards(MockWrapper, {"unsupported": None}) assert not hasattr(MockWrapper, "_private") -async def test_type_wraps_private() -> None: +async def test_type_wraps_private(): Type.generate_wraps(MockWrapper, {"unsupported": None}) assert not hasattr(MockWrapper, "_private") @pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath]) -async def test_path_wraps_path(path, meth) -> None: +async def test_path_wraps_path(path, meth): wrapped = await path.absolute() result = meth(path, wrapped) if result is None: @@ -196,17 +196,17 @@ async def test_path_wraps_path(path, meth) -> None: assert wrapped == result -async def test_path_nonpath() -> None: +async def test_path_nonpath(): with pytest.raises(TypeError): trio.Path(1) -async def test_open_file_can_open_path(path) -> None: +async def test_open_file_can_open_path(path): async with await trio.open_file(path, "w") as f: assert f.name == os.fspath(path) -async def test_globmethods(path) -> None: +async def test_globmethods(path): # Populate a directory tree await path.mkdir() await (path / "foo").mkdir() @@ -235,7 +235,7 @@ async def test_globmethods(path) -> None: assert entries == {"_bar.txt", "bar.txt"} -async def test_iterdir(path) -> None: +async def test_iterdir(path): # Populate a directory await path.mkdir() await (path / "foo").mkdir() @@ -249,7 +249,7 @@ async def test_iterdir(path) -> None: assert entries == {"bar.txt", "foo"} -async def test_classmethods() -> None: +async def test_classmethods(): assert isinstance(await trio.Path.home(), trio.Path) # pathlib.Path has only two classmethods diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index 8af210a0ff..fadddfdaaa 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -21,6 +21,7 @@ import trio from .. import _core, socket as tsocket +from .._abc import Stream from .._core import BrokenResourceError, ClosedResourceError from .._core._tests.tutil import slow from .._highlevel_generic import aclose_forcefully @@ -737,10 +738,20 @@ async def do_wait_send_all_might_not_block() -> None: async def test_wait_writable_calls_underlying_wait_writable() -> None: record = [] - class NotAStream: + class NotAStream(Stream): async def wait_send_all_might_not_block(self) -> None: record.append("ok") + # define methods that are abstract in Stream + async def aclose(self) -> None: + raise AssertionError("Should not get called") + + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: + raise AssertionError("Should not get called") + + async def send_all(self, data: bytes | bytearray | memoryview) -> None: + raise AssertionError("Should not get called") + ctx = ssl.create_default_context() s = SSLStream(NotAStream(), ctx, server_hostname="x") await s.wait_send_all_might_not_block() diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index 0f993976e0..f511ff021e 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -153,6 +153,8 @@ async def test_async_with_basics_deprecated(recwarn) -> None: ) as proc: pass assert proc.returncode is not None + assert proc.stdin is not None + assert proc.stdout is not None with pytest.raises(ClosedResourceError): await proc.stdin.send_all(b"x") with pytest.raises(ClosedResourceError): @@ -418,7 +420,7 @@ async def test_stderr_stdout(background_process) -> None: async def test_errors() -> None: with pytest.raises(TypeError) as excinfo: - await open_process(["ls"], encoding="utf-8") + await open_process(["ls"], encoding="utf-8") # type: ignore[call-overload] assert "unbuffered byte streams" in str(excinfo.value) assert "the 'encoding' option is not supported" in str(excinfo.value) @@ -563,7 +565,9 @@ async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch) -> async def test_run_process_background_fail() -> None: with pytest.raises(subprocess.CalledProcessError): async with _core.open_nursery() as nursery: - proc = await nursery.start(run_process, EXIT_FALSE) + proc: subprocess.CompletedProcess[bytes] = await nursery.start( + run_process, EXIT_FALSE + ) assert proc.returncode == 1 diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py index 7747740c9f..91c865e085 100644 --- a/trio/_tests/test_sync.py +++ b/trio/_tests/test_sync.py @@ -176,7 +176,7 @@ async def test_CapacityLimiter_memleak_548() -> None: async def test_Semaphore() -> None: with pytest.raises(TypeError): - Semaphore(1.0) + Semaphore(1.0) # type: ignore[arg-type] with pytest.raises(ValueError): Semaphore(-1) s = Semaphore(1) @@ -224,7 +224,7 @@ async def do_acquire(s) -> None: async def test_Semaphore_bounded() -> None: with pytest.raises(TypeError): - Semaphore(1, max_value=1.0) + Semaphore(1, max_value=1.0) # type: ignore[arg-type] with pytest.raises(ValueError): Semaphore(2, max_value=1) bs = Semaphore(1, max_value=1) @@ -317,9 +317,9 @@ async def holder() -> None: async def test_Condition() -> None: with pytest.raises(TypeError): - Condition(Semaphore(1)) + Condition(Semaphore(1)) # type: ignore[arg-type] with pytest.raises(TypeError): - Condition(StrictFIFOLock) + Condition(StrictFIFOLock) # type: ignore[arg-type] l = Lock() # noqa c = Condition(l) assert not l.locked() @@ -407,8 +407,8 @@ async def waiter(i) -> None: class ChannelLock1(AsyncContextManagerMixin): - def __init__(self, capacity) -> None: - self.s, self.r = open_memory_channel(capacity) + def __init__(self, capacity: int) -> None: + self.s, self.r = open_memory_channel[None](capacity) for _ in range(capacity - 1): self.s.send_nowait(None) @@ -424,7 +424,7 @@ def release(self) -> None: class ChannelLock2(AsyncContextManagerMixin): def __init__(self) -> None: - self.s, self.r = open_memory_channel(10) + self.s, self.r = open_memory_channel[None](10) self.s.send_nowait(None) def acquire_nowait(self) -> None: @@ -439,7 +439,7 @@ def release(self) -> None: class ChannelLock3(AsyncContextManagerMixin): def __init__(self) -> None: - self.s, self.r = open_memory_channel(0) + self.s, self.r = open_memory_channel[None](0) # self.acquired is true when one task acquires the lock and # only becomes false when it's released and no tasks are # waiting to acquire. diff --git a/trio/_tests/test_testing.py b/trio/_tests/test_testing.py index e7e0f95535..fd953e47bf 100644 --- a/trio/_tests/test_testing.py +++ b/trio/_tests/test_testing.py @@ -261,7 +261,7 @@ async def test__UnboundeByteQueue() -> None: ubq.get_nowait() with pytest.raises(TypeError): - ubq.put("string") + ubq.put("string") # type: ignore[arg-type] ubq.put(b"abc") with assert_checkpoints(): diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 246b50533f..860bc977c0 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -8,7 +8,7 @@ import time import weakref from functools import partial -from typing import Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional import pytest import sniffio @@ -31,7 +31,7 @@ async def test_do_in_trio_thread() -> None: trio_thread = threading.current_thread() async def check_case(do_in_trio_thread, fn, expected, trio_token=None) -> None: - record = [] + record: list[tuple[str, threading.Thread | type[BaseException]]] = [] def threadfn() -> None: try: @@ -51,35 +51,35 @@ def threadfn() -> None: token = _core.current_trio_token() - def f(record) -> int: + def f1(record) -> int: assert not _core.currently_ki_protected() record.append(("f", threading.current_thread())) return 2 - await check_case(from_thread_run_sync, f, ("got", 2), trio_token=token) + await check_case(from_thread_run_sync, f1, ("got", 2), trio_token=token) - def f(record): + def f2(record): assert not _core.currently_ki_protected() record.append(("f", threading.current_thread())) raise ValueError - await check_case(from_thread_run_sync, f, ("error", ValueError), trio_token=token) + await check_case(from_thread_run_sync, f2, ("error", ValueError), trio_token=token) - async def f(record) -> int: + async def f3(record) -> int: assert not _core.currently_ki_protected() await _core.checkpoint() record.append(("f", threading.current_thread())) return 3 - await check_case(from_thread_run, f, ("got", 3), trio_token=token) + await check_case(from_thread_run, f3, ("got", 3), trio_token=token) - async def f(record): + async def f4(record): assert not _core.currently_ki_protected() await _core.checkpoint() record.append(("f", threading.current_thread())) raise KeyError - await check_case(from_thread_run, f, ("error", KeyError), trio_token=token) + await check_case(from_thread_run, f4, ("error", KeyError), trio_token=token) async def test_do_in_trio_thread_from_trio_thread() -> None: @@ -300,7 +300,7 @@ def g(): async def test_run_in_worker_thread_cancellation() -> None: - register = [None] + register: list[str | None] = [None] def f(q) -> None: # Make the thread block for a controlled amount of time @@ -315,8 +315,8 @@ async def child(q, cancellable): finally: record.append("exit") - record = [] - q = stdlib_queue.Queue() + record: list[str] = [] + q = stdlib_queue.Queue[None]() async with _core.open_nursery() as nursery: nursery.start_soon(child, q, True) # Give it a chance to get started. (This is important because @@ -362,8 +362,8 @@ async def child(q, cancellable): def test_run_in_worker_thread_abandoned(capfd, monkeypatch) -> None: monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) - q1 = stdlib_queue.Queue() - q2 = stdlib_queue.Queue() + q1 = stdlib_queue.Queue[None]() + q2 = stdlib_queue.Queue[threading.Thread]() def thread_fn() -> None: q1.get() @@ -427,6 +427,11 @@ async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter) -> # Mutating them in-place is OK though (as long as you use proper # locking etc.). class state: + if TYPE_CHECKING: + ran: int + high_water: int + running: int + parked: int pass state.ran = 0 @@ -515,7 +520,9 @@ def release_on_behalf_of(self, borrower) -> None: record.append("release") assert borrower == self._borrower - await to_thread_run_sync(lambda: None, limiter=CustomLimiter()) + # TODO: should CapacityLimiter have an abc or protocol so users can modify it? + # because currently it's `final` so writing code like this is not allowed. + await to_thread_run_sync(lambda: None, limiter=CustomLimiter()) # type: ignore[arg-type] assert record == ["acquire", "release"] @@ -533,16 +540,16 @@ def release_on_behalf_of(self, borrower): bs = BadCapacityLimiter() with pytest.raises(ValueError) as excinfo: - await to_thread_run_sync(lambda: None, limiter=bs) + await to_thread_run_sync(lambda: None, limiter=bs) # type: ignore[arg-type] assert excinfo.value.__context__ is None assert record == ["acquire", "release"] record = [] # If the original function raised an error, then the semaphore error # chains with it - d = {} + d: dict[str, object] = {} with pytest.raises(ValueError) as excinfo: - await to_thread_run_sync(lambda: d["x"], limiter=bs) + await to_thread_run_sync(lambda: d["x"], limiter=bs) # type: ignore[arg-type] assert isinstance(excinfo.value.__context__, KeyError) assert record == ["acquire", "release"] @@ -583,7 +590,7 @@ async def async_fn() -> None: # pragma: no cover pass with pytest.raises(TypeError, match="expected a sync function"): - await to_thread_run_sync(async_fn) + await to_thread_run_sync(async_fn) # type: ignore[unused-coroutine] trio_test_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar( @@ -631,22 +638,22 @@ def g(): async def test_trio_from_thread_run_sync() -> None: # Test that to_thread_run_sync correctly "hands off" the trio token to # trio.from_thread.run_sync() - def thread_fn(): + def thread_fn_1(): trio_time = from_thread_run_sync(_core.current_time) return trio_time - trio_time = await to_thread_run_sync(thread_fn) + trio_time = await to_thread_run_sync(thread_fn_1) assert isinstance(trio_time, float) # Test correct error when passed async function async def async_fn() -> None: # pragma: no cover pass - def thread_fn() -> None: - from_thread_run_sync(async_fn) + def thread_fn_2() -> None: + from_thread_run_sync(async_fn) # type: ignore[unused-coroutine] with pytest.raises(TypeError, match="expected a sync function"): - await to_thread_run_sync(thread_fn) + await to_thread_run_sync(thread_fn_2) async def test_trio_from_thread_run() -> None: @@ -793,7 +800,10 @@ async def async_back_in_main(): def test_run_fn_as_system_task_catched_badly_typed_token() -> None: with pytest.raises(RuntimeError): - from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype") + from_thread_run_sync( + _core.current_time, + trio_token="Not TrioTokentype", # type: ignore[arg-type] + ) async def test_from_thread_inside_trio_thread() -> None: @@ -841,4 +851,4 @@ def __bool__(self) -> bool: raise NotImplementedError with pytest.raises(NotImplementedError): - await to_thread_run_sync(int, cancellable=BadBool()) + await to_thread_run_sync(int, cancellable=BadBool()) # type: ignore[arg-type] diff --git a/trio/_tests/test_util.py b/trio/_tests/test_util.py index ef99e8f66f..f074a06096 100644 --- a/trio/_tests/test_util.py +++ b/trio/_tests/test_util.py @@ -108,7 +108,7 @@ async def f() -> None: # pragma: no cover pass with pytest.raises(TypeError) as excinfo: - coroutine_or_error(f()) + coroutine_or_error(f()) # type: ignore[arg-type, unused-coroutine] assert "expecting an async function" in str(excinfo.value) import asyncio @@ -120,35 +120,37 @@ def generator_based_coro(): # pragma: no cover yield from asyncio.sleep(1) with pytest.raises(TypeError) as excinfo: - coroutine_or_error(generator_based_coro()) + coroutine_or_error(generator_based_coro()) # type: ignore[arg-type, unused-coroutine] assert "asyncio" in str(excinfo.value) with pytest.raises(TypeError) as excinfo: - coroutine_or_error(create_asyncio_future_in_new_loop()) + coroutine_or_error(create_asyncio_future_in_new_loop()) # type: ignore[arg-type, unused-coroutine] assert "asyncio" in str(excinfo.value) + # does not raise arg-type error with pytest.raises(TypeError) as excinfo: - coroutine_or_error(create_asyncio_future_in_new_loop) + coroutine_or_error(create_asyncio_future_in_new_loop) # type: ignore[unused-coroutine] assert "asyncio" in str(excinfo.value) with pytest.raises(TypeError) as excinfo: - coroutine_or_error(Deferred()) + coroutine_or_error(Deferred()) # type: ignore[arg-type, unused-coroutine] assert "twisted" in str(excinfo.value) with pytest.raises(TypeError) as excinfo: - coroutine_or_error(lambda: Deferred()) + coroutine_or_error(lambda: Deferred()) # type: ignore[arg-type, unused-coroutine, return-value] assert "twisted" in str(excinfo.value) with pytest.raises(TypeError) as excinfo: - coroutine_or_error(len, [[1, 2, 3]]) + coroutine_or_error(len, [[1, 2, 3]]) # type: ignore[arg-type, unused-coroutine] assert "appears to be synchronous" in str(excinfo.value) async def async_gen(arg): # pragma: no cover yield + # does not give arg-type typing error with pytest.raises(TypeError) as excinfo: - coroutine_or_error(async_gen, [0]) + coroutine_or_error(async_gen, [0]) # type: ignore[unused-coroutine] msg = "expected an async function but got an async generator" assert msg in str(excinfo.value) @@ -165,8 +167,8 @@ def test_func(arg): assert test_func is test_func[int] is test_func[int, str] assert test_func(42) == test_func[int](42) == 42 assert test_func.__doc__ == "Look, a docstring!" - assert test_func.__qualname__ == "test_generic_function..test_func" - assert test_func.__name__ == "test_func" + assert test_func.__qualname__ == "test_generic_function..test_func" # type: ignore[attr-defined] + assert test_func.__name__ == "test_func" # type: ignore[attr-defined] assert test_func.__module__ == __name__ @@ -206,7 +208,7 @@ def __init__(self, a: int, b: float) -> None: def test_fixup_module_metadata() -> None: # Ignores modules not in the trio.X tree. non_trio_module = types.ModuleType("not_trio") - non_trio_module.some_func = lambda: None + non_trio_module.some_func = lambda: None # type: ignore[attr-defined] non_trio_module.some_func.__name__ = "some_func" non_trio_module.some_func.__qualname__ = "some_func" @@ -217,26 +219,26 @@ def test_fixup_module_metadata() -> None: # Bulild up a fake module to test. Just use lambdas since all we care about is the names. mod = types.ModuleType("trio._somemodule_impl") - mod.some_func = lambda: None + mod.some_func = lambda: None # type: ignore[attr-defined] mod.some_func.__name__ = "_something_else" mod.some_func.__qualname__ = "_something_else" # No __module__ means it's unchanged. - mod.not_funclike = types.SimpleNamespace() + mod.not_funclike = types.SimpleNamespace() # type: ignore[attr-defined] mod.not_funclike.__name__ = "not_funclike" # Check __qualname__ being absent works. - mod.only_has_name = types.SimpleNamespace() + mod.only_has_name = types.SimpleNamespace() # type: ignore[attr-defined] mod.only_has_name.__module__ = "trio._somemodule_impl" mod.only_has_name.__name__ = "only_name" # Underscored names are unchanged. - mod._private = lambda: None + mod._private = lambda: None # type: ignore[attr-defined] mod._private.__module__ = "trio._somemodule_impl" mod._private.__name__ = mod._private.__qualname__ = "_private" # We recurse into classes. - mod.SomeClass = type( + mod.SomeClass = type( # type: ignore[attr-defined] "SomeClass", (), { @@ -244,7 +246,8 @@ def test_fixup_module_metadata() -> None: "method": lambda self: None, }, ) - mod.SomeClass.recursion = mod.SomeClass # Reference loop is fine. + # Reference loop is fine. + mod.SomeClass.recursion = mod.SomeClass # type: ignore[attr-defined] fixup_module_metadata("trio.somemodule", vars(mod)) assert mod.some_func.__name__ == "some_func" @@ -260,9 +263,9 @@ def test_fixup_module_metadata() -> None: assert mod.only_has_name.__module__ == "trio.somemodule" assert not hasattr(mod.only_has_name, "__qualname__") - assert mod.SomeClass.method.__name__ == "method" - assert mod.SomeClass.method.__module__ == "trio.somemodule" - assert mod.SomeClass.method.__qualname__ == "SomeClass.method" + assert mod.SomeClass.method.__name__ == "method" # type: ignore[attr-defined] + assert mod.SomeClass.method.__module__ == "trio.somemodule" # type: ignore[attr-defined] + assert mod.SomeClass.method.__qualname__ == "SomeClass.method" # type: ignore[attr-defined] # Make coverage happy. non_trio_module.some_func() mod.some_func() diff --git a/trio/_tests/test_windows_pipes.py b/trio/_tests/test_windows_pipes.py index 399d7116bb..5c4bae7d25 100644 --- a/trio/_tests/test_windows_pipes.py +++ b/trio/_tests/test_windows_pipes.py @@ -24,14 +24,14 @@ async def make_pipe() -> Tuple[PipeSendStream, PipeReceiveStream]: return PipeSendStream(w), PipeReceiveStream(r) -async def test_pipe_typecheck() -> None: +async def test_pipe_typecheck(): with pytest.raises(TypeError): PipeSendStream(1.0) with pytest.raises(TypeError): PipeReceiveStream(None) -async def test_pipe_error_on_close() -> None: +async def test_pipe_error_on_close(): # Make sure we correctly handle a failure from kernel32.CloseHandle r, w = pipe() @@ -47,18 +47,18 @@ async def test_pipe_error_on_close() -> None: await receive_stream.aclose() -async def test_pipes_combined() -> None: +async def test_pipes_combined(): write, read = await make_pipe() count = 2**20 replicas = 3 - async def sender() -> None: + async def sender(): async with write: big = bytearray(count) for _ in range(replicas): await write.send_all(big) - async def reader() -> None: + async def reader(): async with read: await wait_all_tasks_blocked() total_received = 0 @@ -76,7 +76,7 @@ async def reader() -> None: n.start_soon(reader) -async def test_async_with() -> None: +async def test_async_with(): w, r = await make_pipe() async with w, r: pass @@ -87,11 +87,11 @@ async def test_async_with() -> None: await r.receive_some(10) -async def test_close_during_write() -> None: +async def test_close_during_write(): w, r = await make_pipe() async with _core.open_nursery() as nursery: - async def write_forever() -> None: + async def write_forever(): with pytest.raises(_core.ClosedResourceError) as excinfo: while True: await w.send_all(b"x" * 4096) @@ -102,7 +102,7 @@ async def write_forever() -> None: await w.aclose() -async def test_pipe_fully() -> None: +async def test_pipe_fully(): # passing make_clogged_pipe tests wait_send_all_might_not_block, and we # can't implement that on Windows await check_one_way_stream(make_pipe, None) From 9dd47b3bba6b679dc6b8ed043906f59e346acb66 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 17 Oct 2023 13:42:54 +0200 Subject: [PATCH 05/35] enable check_untyped_defs --- pyproject.toml | 1 - trio/_tests/test_highlevel_serve_listeners.py | 2 +- trio/_tests/test_ssl.py | 4 +++- trio/_tests/test_subprocess.py | 11 ++++++++--- trio/_tests/test_testing.py | 2 +- trio/_tests/test_tracing.py | 15 +++++++++++---- 6 files changed, 24 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 96e8b008bf..4ef771fcd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,7 +112,6 @@ module = [ "trio/_tests/test_wait_for_object", "trio/_tests/tools/test_gen_exports", ] -check_untyped_defs = false disallow_any_decorated = false disallow_any_generics = false disallow_any_unimported = false diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py index cf3c79f56f..70b01042b2 100644 --- a/trio/_tests/test_highlevel_serve_listeners.py +++ b/trio/_tests/test_highlevel_serve_listeners.py @@ -13,7 +13,7 @@ @attr.s(hash=False, eq=False) class MemoryListener(trio.abc.Listener): closed = attr.ib(default=False) - accepted_streams = attr.ib(factory=list) + accepted_streams: list[trio.abc.Stream] = attr.ib(factory=list) queued_streams = attr.ib( factory=(lambda: trio.open_memory_channel[trio.StapledStream](1)) ) diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index fadddfdaaa..85d7f43069 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -167,7 +167,9 @@ async def ssl_echo_server(client_ctx, **kwargs): # The weird in-memory server ... thing. # Doesn't inherit from Stream because I left out the methods that we don't # actually need. -class PyOpenSSLEchoStream: +# jakkdl: it seems to implement all the abstract methods (now), so I made it inherit +# from Stream for the sake of typechecking. +class PyOpenSSLEchoStream(Stream): def __init__(self, sleeper=None) -> None: ctx = SSL.Context(SSL.SSLv23_METHOD) # TLS 1.3 removes renegotiation support. Which is great for them, but diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index f511ff021e..d6bf47ce82 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -8,7 +8,7 @@ from contextlib import asynccontextmanager from functools import partial from pathlib import Path as SyncPath -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, AsyncIterator import pytest @@ -80,11 +80,16 @@ async def open_process_then_kill(*args, **kwargs): await proc.wait() +# not entirely sure about this annotation @asynccontextmanager -async def run_process_in_nursery(*args, **kwargs): +async def run_process_in_nursery( + *args, **kwargs +) -> AsyncIterator[subprocess.CompletedProcess[bytes]]: async with _core.open_nursery() as nursery: kwargs.setdefault("check", False) - proc = await nursery.start(partial(run_process, *args, **kwargs)) + proc: subprocess.CompletedProcess[bytes] = await nursery.start( + partial(run_process, *args, **kwargs) + ) yield proc nursery.cancel_scope.cancel() diff --git a/trio/_tests/test_testing.py b/trio/_tests/test_testing.py index fd953e47bf..08f896b725 100644 --- a/trio/_tests/test_testing.py +++ b/trio/_tests/test_testing.py @@ -237,7 +237,7 @@ async def test__assert_raises(): with pytest.raises(TypeError): with _assert_raises(RuntimeError): - "foo" + 1 + "foo" + 1 # type: ignore[operator] with _assert_raises(RuntimeError): raise RuntimeError diff --git a/trio/_tests/test_tracing.py b/trio/_tests/test_tracing.py index 0cef2b0f44..29fbb7a475 100644 --- a/trio/_tests/test_tracing.py +++ b/trio/_tests/test_tracing.py @@ -1,3 +1,5 @@ +from typing import AsyncGenerator + import trio @@ -14,10 +16,15 @@ async def coro3(event: trio.Event) -> None: await coro2(event) -async def coro2_async_gen(event): - yield await trio.lowlevel.checkpoint() - yield await coro1(event) - yield await trio.lowlevel.checkpoint() +async def coro2_async_gen(event) -> AsyncGenerator[None, None]: + # mypy does not like `yield await trio.lowlevel.checkpoint()` - but that + # should be equivalent to splitting the statement + await trio.lowlevel.checkpoint() + yield + await coro1(event) + yield + await trio.lowlevel.checkpoint() + yield async def coro3_async_gen(event: trio.Event) -> None: From 787b41e2b5ea6453c350a1a06bf8774a72267bef Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 17 Oct 2023 15:16:19 +0200 Subject: [PATCH 06/35] enable disallow_any_generics, fully type highlevel_serve_listers --- pyproject.toml | 2 - trio/_core/_tests/test_guest_mode.py | 4 +- trio/_tests/test_highlevel_serve_listeners.py | 61 +++++++++++++------ 3 files changed, 45 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4ef771fcd7..da22a6d271 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,6 @@ module = [ "trio/_tests/test_file_io", "trio/_tests/test_highlevel_generic", "trio/_tests/test_highlevel_open_unix_stream", -"trio/_tests/test_highlevel_serve_listeners", "trio/_tests/test_highlevel_ssl_helpers", "trio/_tests/test_scheduler_determinism", "trio/_tests/test_ssl", @@ -113,7 +112,6 @@ module = [ "trio/_tests/tools/test_gen_exports", ] disallow_any_decorated = false -disallow_any_generics = false disallow_any_unimported = false disallow_incomplete_defs = false disallow_untyped_defs = false diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index 3789cdefa3..bb3eb5e01a 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -35,7 +35,7 @@ def trivial_guest_run(trio_fn, *, in_host_after_start=None, **start_guest_run_kw host_thread = threading.current_thread() - def run_sync_soon_threadsafe(fn: Callable) -> None: + def run_sync_soon_threadsafe(fn: Callable[[], object]) -> None: nonlocal todo if host_thread is threading.current_thread(): # pragma: no cover crash = partial( @@ -44,7 +44,7 @@ def run_sync_soon_threadsafe(fn: Callable) -> None: todo.put(("run", crash)) todo.put(("run", fn)) - def run_sync_soon_not_threadsafe(fn: Callable) -> None: + def run_sync_soon_not_threadsafe(fn: Callable[[], object]) -> None: nonlocal todo if host_thread is not threading.current_thread(): # pragma: no cover crash = partial( diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py index 70b01042b2..2e333068ec 100644 --- a/trio/_tests/test_highlevel_serve_listeners.py +++ b/trio/_tests/test_highlevel_serve_listeners.py @@ -1,31 +1,49 @@ from __future__ import annotations import errno +import sys from functools import partial +from typing import Awaitable, Callable, NoReturn import attr import pytest import trio -from trio.testing import memory_stream_pair, wait_all_tasks_blocked +from trio import Nursery, StapledStream, TaskStatus +from trio._channel import MemoryReceiveChannel, MemorySendChannel +from trio.abc import Stream +from trio.testing import ( + MemoryReceiveStream, + MemorySendStream, + MockClock, + memory_stream_pair, + wait_all_tasks_blocked, +) + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + +# types are somewhat tentative - I just bruteforced them until I got something that didn't +# give errors +TypeThing = StapledStream[MemorySendStream, MemoryReceiveStream] @attr.s(hash=False, eq=False) -class MemoryListener(trio.abc.Listener): - closed = attr.ib(default=False) +class MemoryListener(trio.abc.Listener[TypeThing]): + closed: bool = attr.ib(default=False) accepted_streams: list[trio.abc.Stream] = attr.ib(factory=list) - queued_streams = attr.ib( - factory=(lambda: trio.open_memory_channel[trio.StapledStream](1)) - ) - accept_hook = attr.ib(default=None) + queued_streams: tuple[ + MemorySendChannel[TypeThing], MemoryReceiveChannel[TypeThing] + ] = attr.ib(factory=(lambda: trio.open_memory_channel[TypeThing](1))) + accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None) - async def connect(self): + async def connect(self) -> StapledStream[MemorySendStream, MemoryReceiveStream]: assert not self.closed client, server = memory_stream_pair() await self.queued_streams[0].send(server) return client - async def accept(self): + async def accept(self) -> TypeThing: await trio.lowlevel.checkpoint() assert not self.closed if self.accept_hook is not None: @@ -49,18 +67,20 @@ def close_hook() -> None: assert trio.current_effective_deadline() == float("-inf") record.append("closed") - async def handler(stream) -> None: + async def handler( + stream: StapledStream[MemorySendStream, MemoryReceiveStream] + ) -> None: await stream.send_all(b"123") assert await stream.receive_some(10) == b"456" stream.send_stream.close_hook = close_hook stream.receive_stream.close_hook = close_hook - async def client(listener) -> None: + async def client(listener: MemoryListener) -> None: s = await listener.connect() assert await s.receive_some(10) == b"123" await s.send_all(b"456") - async def do_tests(parent_nursery) -> None: + async def do_tests(parent_nursery: Nursery) -> None: async with trio.open_nursery() as nursery: for listener in listeners: for _ in range(3): @@ -90,7 +110,7 @@ async def test_serve_listeners_accept_unrecognized_error() -> None: for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]: listener = MemoryListener() - async def raise_error(): + async def raise_error() -> NoReturn: raise error listener.accept_hook = raise_error @@ -100,10 +120,12 @@ async def raise_error(): assert excinfo.value is error -async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog) -> None: +async def test_serve_listeners_accept_capacity_error( + autojump_clock: MockClock, caplog: pytest.LogCaptureFixture +) -> None: listener = MemoryListener() - async def raise_EMFILE(): + async def raise_EMFILE() -> NoReturn: raise OSError(errno.EMFILE, "out of file descriptors") listener.accept_hook = raise_EMFILE @@ -116,19 +138,22 @@ async def raise_EMFILE(): assert len(caplog.records) == 10 for record in caplog.records: assert "retrying" in record.msg + assert isinstance(record.exc_info, ExceptionGroup) assert record.exc_info[1].errno == errno.EMFILE -async def test_serve_listeners_connection_nursery(autojump_clock) -> None: +async def test_serve_listeners_connection_nursery(autojump_clock: MockClock) -> None: listener = MemoryListener() - async def handler(stream) -> None: + async def handler(stream: Stream) -> None: await trio.sleep(1) class Done(Exception): pass - async def connection_watcher(*, task_status=trio.TASK_STATUS_IGNORED): + async def connection_watcher( + *, task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED + ) -> NoReturn: async with trio.open_nursery() as nursery: task_status.started(nursery) await wait_all_tasks_blocked() From 2edff671d242c372a9505a16baad582bd0b375b5 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 17 Oct 2023 15:39:11 +0200 Subject: [PATCH 07/35] type a bunch of files with few type errors --- pyproject.toml | 13 ------------- trio/_core/_tests/test_guest_mode.py | 3 ++- trio/_core/_tests/test_mock_clock.py | 10 ++++++---- .../test_multierror_scripts/simple_excepthook.py | 4 ++-- trio/_core/_tests/test_parking_lot.py | 4 ++-- trio/_core/_tests/test_thread_cache.py | 5 +++-- trio/_tests/test_highlevel_generic.py | 8 +++++--- trio/_tests/test_highlevel_open_unix_stream.py | 6 +++--- trio/_tests/test_scheduler_determinism.py | 10 ++++++++-- trio/_tests/test_ssl.py | 3 ++- trio/_tests/test_subprocess.py | 7 +++++-- trio/_tests/test_testing.py | 2 +- trio/_tests/test_threads.py | 5 +++-- trio/_tests/test_timeouts.py | 11 +++++++++-- trio/_tests/test_tracing.py | 2 +- trio/_tests/test_unix_pipes.py | 5 +++-- trio/_tests/test_wait_for_object.py | 4 ++-- 17 files changed, 57 insertions(+), 45 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index da22a6d271..6060954e38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,35 +84,22 @@ disallow_untyped_calls = false [[tool.mypy.overrides]] module = [ # tests -"trio/testing/_fake_net", "trio/_core/_tests/test_guest_mode", "trio/_core/_tests/test_ki", -"trio/_core/_tests/test_mock_clock", "trio/_core/_tests/test_multierror", -"trio/_core/_tests/test_multierror_scripts/ipython_custom_exc", -"trio/_core/_tests/test_multierror_scripts/simple_excepthook", -"trio/_core/_tests/test_parking_lot", "trio/_core/_tests/test_thread_cache", -"trio/_core/_tests/test_unbounded_queue", "trio/_tests/test_exports", "trio/_tests/test_file_io", -"trio/_tests/test_highlevel_generic", -"trio/_tests/test_highlevel_open_unix_stream", "trio/_tests/test_highlevel_ssl_helpers", -"trio/_tests/test_scheduler_determinism", "trio/_tests/test_ssl", "trio/_tests/test_subprocess", "trio/_tests/test_sync", "trio/_tests/test_testing", "trio/_tests/test_threads", -"trio/_tests/test_timeouts", -"trio/_tests/test_tracing", "trio/_tests/test_util", -"trio/_tests/test_wait_for_object", "trio/_tests/tools/test_gen_exports", ] disallow_any_decorated = false -disallow_any_unimported = false disallow_incomplete_defs = false disallow_untyped_defs = false diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index bb3eb5e01a..ea0deae81e 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -16,6 +16,7 @@ import pytest from outcome import Outcome +from pytest import MonkeyPatch import trio import trio.testing @@ -490,7 +491,7 @@ async def aio_pingpong(from_trio, to_trio): ) -def test_guest_mode_internal_errors(monkeypatch, recwarn) -> None: +def test_guest_mode_internal_errors(monkeypatch: MonkeyPatch, recwarn) -> None: with monkeypatch.context() as m: async def crash_in_run_loop(in_host) -> None: diff --git a/trio/_core/_tests/test_mock_clock.py b/trio/_core/_tests/test_mock_clock.py index 4d655cd1a4..1a0c8b3444 100644 --- a/trio/_core/_tests/test_mock_clock.py +++ b/trio/_core/_tests/test_mock_clock.py @@ -55,7 +55,7 @@ def test_mock_clock() -> None: assert c2.current_time() < 10 -async def test_mock_clock_autojump(mock_clock) -> None: +async def test_mock_clock_autojump(mock_clock: MockClock) -> None: assert mock_clock.autojump_threshold == inf mock_clock.autojump_threshold = 0 @@ -95,7 +95,7 @@ async def test_mock_clock_autojump(mock_clock) -> None: await sleep(100000) -async def test_mock_clock_autojump_interference(mock_clock) -> None: +async def test_mock_clock_autojump_interference(mock_clock: MockClock) -> None: mock_clock.autojump_threshold = 0.02 mock_clock2 = MockClock() @@ -122,7 +122,9 @@ def test_mock_clock_autojump_preset() -> None: assert time.perf_counter() - real_start < 1 -async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock) -> None: +async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0( + mock_clock: MockClock, +) -> None: # Checks that autojump_threshold=0 doesn't interfere with # calling wait_all_tasks_blocked with the default cushion=0. @@ -149,7 +151,7 @@ async def waiter() -> None: @slow async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero( - mock_clock, + mock_clock: MockClock, ) -> None: # Checks that autojump_threshold=0 doesn't interfere with # calling wait_all_tasks_blocked with a non-zero cushion. diff --git a/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py index 65371107bc..236d34e9ba 100644 --- a/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py +++ b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py @@ -3,14 +3,14 @@ from trio._core._multierror import MultiError # Bypass deprecation warnings -def exc1_fn(): +def exc1_fn() -> Exception: try: raise ValueError except Exception as exc: return exc -def exc2_fn(): +def exc2_fn() -> Exception: try: raise KeyError except Exception as exc: diff --git a/trio/_core/_tests/test_parking_lot.py b/trio/_core/_tests/test_parking_lot.py index 2946a2fb3c..74a4704bef 100644 --- a/trio/_core/_tests/test_parking_lot.py +++ b/trio/_core/_tests/test_parking_lot.py @@ -15,7 +15,7 @@ async def test_parking_lot_basic() -> None: record = [] - async def waiter(i, lot) -> None: + async def waiter(i: int, lot: ParkingLot) -> None: record.append(f"sleep {i}") await lot.park() record.append(f"wake {i}") @@ -83,7 +83,7 @@ async def waiter(i, lot) -> None: async def cancellable_waiter( - name: T, lot, scopes: dict[T, _core.CancelScope], record: list[str] + name: T, lot: ParkingLot, scopes: dict[T, _core.CancelScope], record: list[str] ) -> None: with _core.CancelScope() as scope: scopes[name] = scope diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py index 4893ac3cae..0241917b91 100644 --- a/trio/_core/_tests/test_thread_cache.py +++ b/trio/_core/_tests/test_thread_cache.py @@ -8,6 +8,7 @@ import pytest from outcome import Outcome +from pytest import MonkeyPatch from .. import _thread_cache from .._thread_cache import ThreadCache, start_thread_soon @@ -89,7 +90,7 @@ def deliver(n, _) -> None: @slow -def test_idle_threads_exit(monkeypatch) -> None: +def test_idle_threads_exit(monkeypatch: MonkeyPatch) -> None: # Temporarily set the idle timeout to something tiny, to speed up the # test. (But non-zero, so that the worker loop will at least yield the # CPU.) @@ -116,7 +117,7 @@ def _join_started_threads(): assert not thread.is_alive() -def test_race_between_idle_exit_and_job_assignment(monkeypatch) -> None: +def test_race_between_idle_exit_and_job_assignment(monkeypatch: MonkeyPatch) -> None: # This is a lock where the first few times you try to acquire it with a # timeout, it waits until the lock is available and then pretends to time # out. Using this in our thread cache implementation causes the following diff --git a/trio/_tests/test_highlevel_generic.py b/trio/_tests/test_highlevel_generic.py index 5ccc69151f..4b2008c08c 100644 --- a/trio/_tests/test_highlevel_generic.py +++ b/trio/_tests/test_highlevel_generic.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import NoReturn + import attr import pytest @@ -11,7 +13,7 @@ class RecordSendStream(SendStream): record: list[str | tuple[str, object]] = attr.ib(factory=list) - async def send_all(self, data) -> None: + async def send_all(self, data: object) -> None: self.record.append(("send_all", data)) async def wait_send_all_might_not_block(self) -> None: @@ -76,12 +78,12 @@ async def test_StapledStream_with_erroring_close() -> None: # Make sure that if one of the aclose methods errors out, then the other # one still gets called. class BrokenSendStream(RecordSendStream): - async def aclose(self): + async def aclose(self) -> NoReturn: await super().aclose() raise ValueError class BrokenReceiveStream(RecordReceiveStream): - async def aclose(self): + async def aclose(self) -> NoReturn: await super().aclose() raise ValueError diff --git a/trio/_tests/test_highlevel_open_unix_stream.py b/trio/_tests/test_highlevel_open_unix_stream.py index cf32b8c0fc..045820fccf 100644 --- a/trio/_tests/test_highlevel_open_unix_stream.py +++ b/trio/_tests/test_highlevel_open_unix_stream.py @@ -11,7 +11,7 @@ pytestmark = pytest.mark.skip("Needs unix socket support") -def test_close_on_error(): +def test_close_on_error() -> None: class CloseMe: closed = False @@ -29,9 +29,9 @@ def close(self) -> None: @pytest.mark.parametrize("filename", [4, 4.5]) -async def test_open_with_bad_filename_type(filename) -> None: +async def test_open_with_bad_filename_type(filename: float) -> None: with pytest.raises(TypeError): - await open_unix_socket(filename) + await open_unix_socket(filename) # type: ignore[arg-type] async def test_open_bad_socket() -> None: diff --git a/trio/_tests/test_scheduler_determinism.py b/trio/_tests/test_scheduler_determinism.py index 4c0da698ce..7e0a8e98de 100644 --- a/trio/_tests/test_scheduler_determinism.py +++ b/trio/_tests/test_scheduler_determinism.py @@ -1,7 +1,11 @@ +from __future__ import annotations + +from pytest import MonkeyPatch + import trio -async def scheduler_trace(): +async def scheduler_trace() -> tuple[tuple[str, int], ...]: """Returns a scheduler-dependent value we can use to check determinism.""" trace = [] @@ -25,7 +29,9 @@ def test_the_trio_scheduler_is_not_deterministic() -> None: assert len(set(traces)) == len(traces) -def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch) -> None: +def test_the_trio_scheduler_is_deterministic_if_seeded( + monkeypatch: MonkeyPatch, +) -> None: monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True) traces = [] for _ in range(10): diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index 85d7f43069..53943cabce 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -10,6 +10,7 @@ import pytest +from trio._core import MockClock from trio._tests.pytest_plugin import skip_if_optional_else_raise try: @@ -579,7 +580,7 @@ async def test_renegotiation_simple(client_ctx) -> None: @slow -async def test_renegotiation_randomized(mock_clock, client_ctx) -> None: +async def test_renegotiation_randomized(mock_clock: MockClock, client_ctx) -> None: # The only blocking things in this function are our random sleeps, so 0 is # a good threshold. mock_clock.autojump_threshold = 0 diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index d6bf47ce82..3ebf2589bc 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, AsyncIterator import pytest +from pytest import MonkeyPatch from .. import ( ClosedResourceError, @@ -538,7 +539,7 @@ async def custom_deliver_cancel(proc) -> None: assert custom_deliver_cancel_called -async def test_warn_on_failed_cancel_terminate(monkeypatch) -> None: +async def test_warn_on_failed_cancel_terminate(monkeypatch: MonkeyPatch) -> None: original_terminate = Process.terminate def broken_terminate(self): @@ -555,7 +556,9 @@ def broken_terminate(self): @pytest.mark.skipif(not posix, reason="posix only") -async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch) -> None: +async def test_warn_on_cancel_SIGKILL_escalation( + autojump_clock, monkeypatch: MonkeyPatch +) -> None: monkeypatch.setattr(Process, "terminate", lambda *args: None) with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"): diff --git a/trio/_tests/test_testing.py b/trio/_tests/test_testing.py index 08f896b725..7461de058c 100644 --- a/trio/_tests/test_testing.py +++ b/trio/_tests/test_testing.py @@ -45,7 +45,7 @@ async def cancelled_while_waiting() -> None: assert record == ["ok"] -async def test_wait_all_tasks_blocked_with_timeouts(mock_clock) -> None: +async def test_wait_all_tasks_blocked_with_timeouts(mock_clock: MockClock) -> None: record = [] async def timeout_task() -> None: diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 860bc977c0..0fe5d8dc48 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -12,6 +12,7 @@ import pytest import sniffio +from pytest import MonkeyPatch from trio._core import TrioToken, current_trio_token @@ -359,7 +360,7 @@ async def child(q, cancellable): # Make sure that if trio.run exits, and then the thread finishes, then that's # handled gracefully. (Requires that the thread result machinery be prepared # for call_soon to raise RunFinishedError.) -def test_run_in_worker_thread_abandoned(capfd, monkeypatch) -> None: +def test_run_in_worker_thread_abandoned(capfd, monkeypatch: MonkeyPatch) -> None: monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) q1 = stdlib_queue.Queue[None]() @@ -554,7 +555,7 @@ def release_on_behalf_of(self, borrower): assert record == ["acquire", "release"] -async def test_run_in_worker_thread_fail_to_spawn(monkeypatch) -> None: +async def test_run_in_worker_thread_fail_to_spawn(monkeypatch: MonkeyPatch) -> None: # Test the unlikely but possible case where trying to spawn a thread fails def bad_start(self, *args): raise RuntimeError("the engines canna take it captain") diff --git a/trio/_tests/test_timeouts.py b/trio/_tests/test_timeouts.py index 1491067c05..918e763faa 100644 --- a/trio/_tests/test_timeouts.py +++ b/trio/_tests/test_timeouts.py @@ -1,4 +1,5 @@ import time +from typing import Awaitable, Callable, TypeVar import outcome import pytest @@ -8,8 +9,12 @@ from .._timeouts import * from ..testing import assert_checkpoints +T = TypeVar("T") -async def check_takes_about(f, expected_dur): + +async def check_takes_about( + f: Callable[[], Awaitable[T]], expected_dur: float +) -> Awaitable[T]: start = time.perf_counter() result = await outcome.acapture(f) dur = time.perf_counter() - start @@ -34,7 +39,9 @@ async def check_takes_about(f, expected_dur): # value above is exactly 128 ULPs below 1.0, which would make sense if it # started as a 1 ULP error at a different dynamic range.) assert (1 - 1e-8) <= (dur / expected_dur) < 1.5 - return result.unwrap() + + # outcome is not typed + return result.unwrap() # type: ignore[no-any-return] # How long to (attempt to) sleep for when testing. Smaller numbers make the diff --git a/trio/_tests/test_tracing.py b/trio/_tests/test_tracing.py index 29fbb7a475..405831876c 100644 --- a/trio/_tests/test_tracing.py +++ b/trio/_tests/test_tracing.py @@ -16,7 +16,7 @@ async def coro3(event: trio.Event) -> None: await coro2(event) -async def coro2_async_gen(event) -> AsyncGenerator[None, None]: +async def coro2_async_gen(event: trio.Event) -> AsyncGenerator[None, None]: # mypy does not like `yield await trio.lowlevel.checkpoint()` - but that # should be equivalent to splitting the statement await trio.lowlevel.checkpoint() diff --git a/trio/_tests/test_unix_pipes.py b/trio/_tests/test_unix_pipes.py index f29ee241c6..5705a75faa 100644 --- a/trio/_tests/test_unix_pipes.py +++ b/trio/_tests/test_unix_pipes.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING import pytest +from pytest import MonkeyPatch from .. import _core from .._core._tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken @@ -182,7 +183,7 @@ async def expect_eof() -> None: os.close(w2_fd) -async def test_close_at_bad_time_for_receive_some(monkeypatch) -> None: +async def test_close_at_bad_time_for_receive_some(monkeypatch: MonkeyPatch) -> None: # We used to have race conditions where if one task was using the pipe, # and another closed it at *just* the wrong moment, it would give an # unexpected error instead of ClosedResourceError: @@ -210,7 +211,7 @@ async def patched_wait_readable(*args, **kwargs) -> None: await s.send_all(b"x") -async def test_close_at_bad_time_for_send_all(monkeypatch) -> None: +async def test_close_at_bad_time_for_send_all(monkeypatch: MonkeyPatch) -> None: # We used to have race conditions where if one task was using the pipe, # and another closed it at *just* the wrong moment, it would give an # unexpected error instead of ClosedResourceError: diff --git a/trio/_tests/test_wait_for_object.py b/trio/_tests/test_wait_for_object.py index 53e771b7ed..44790497ed 100644 --- a/trio/_tests/test_wait_for_object.py +++ b/trio/_tests/test_wait_for_object.py @@ -12,7 +12,7 @@ from .._core._tests.tutil import slow if on_windows: - from .._core._windows_cffi import ffi, kernel32 + from .._core._windows_cffi import Handle, ffi, kernel32 from .._wait_for_object import WaitForMultipleObjects_sync, WaitForSingleObject @@ -168,7 +168,7 @@ async def test_WaitForSingleObject_slow() -> None: # the timeout with a certain margin. TIMEOUT = 0.3 - async def signal_soon_async(handle) -> None: + async def signal_soon_async(handle: Handle) -> None: await _timeouts.sleep(TIMEOUT) kernel32.SetEvent(handle) From 718baa0fd69c89d5251707f1dfc06c99b0a48c3f Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 18 Oct 2023 14:53:55 +0200 Subject: [PATCH 08/35] more typing --- pyproject.toml | 8 - trio/_core/_tests/test_guest_mode.py | 92 ++++++---- trio/_core/_tests/test_thread_cache.py | 16 +- trio/_ssl.py | 11 +- trio/_tests/test_exports.py | 22 ++- trio/_tests/test_highlevel_serve_listeners.py | 7 +- trio/_tests/test_highlevel_ssl_helpers.py | 32 +++- trio/_tests/test_ssl.py | 168 +++++++++++------- trio/_tests/test_subprocess.py | 6 +- trio/_tests/test_sync.py | 35 +++- trio/_tests/test_testing.py | 51 +++--- trio/_tests/test_threads.py | 4 +- trio/_tests/test_util.py | 15 +- trio/_tests/test_wait_for_object.py | 2 +- trio/_tests/tools/test_gen_exports.py | 9 +- 15 files changed, 297 insertions(+), 181 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6060954e38..831d577e65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,17 +87,9 @@ module = [ "trio/_core/_tests/test_guest_mode", "trio/_core/_tests/test_ki", "trio/_core/_tests/test_multierror", -"trio/_core/_tests/test_thread_cache", -"trio/_tests/test_exports", "trio/_tests/test_file_io", -"trio/_tests/test_highlevel_ssl_helpers", -"trio/_tests/test_ssl", "trio/_tests/test_subprocess", -"trio/_tests/test_sync", -"trio/_tests/test_testing", "trio/_tests/test_threads", -"trio/_tests/test_util", -"trio/_tests/tools/test_gen_exports", ] disallow_any_decorated = false disallow_incomplete_defs = false diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index ea0deae81e..286e8657a1 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -12,18 +12,23 @@ import warnings from functools import partial from math import inf -from typing import Callable +from typing import TYPE_CHECKING, Any, Awaitable, Callable, NoReturn, TypeVar import pytest from outcome import Outcome -from pytest import MonkeyPatch +from pytest import MonkeyPatch, WarningsRecorder +from typing_extensions import TypeAlias import trio import trio.testing +from trio._channel import MemorySendChannel from ..._util import signal_raise from .tutil import buggy_pypy_asyncgens, gc_collect_harder, restore_unraisablehook +T = TypeVar("T") +InHost: TypeAlias = Callable[[object], None] + # The simplest possible "host" loop. # Nice features: @@ -31,7 +36,12 @@ # our main # - final result is returned # - any unhandled exceptions cause an immediate crash -def trivial_guest_run(trio_fn, *, in_host_after_start=None, **start_guest_run_kwargs): +def trivial_guest_run( + trio_fn: Callable[..., Awaitable[T]], + *, + in_host_after_start: Callable[[], None] | None = None, + **start_guest_run_kwargs: Any, +) -> T: todo: queue.Queue[tuple[str, Outcome]] = queue.Queue() host_thread = threading.current_thread() @@ -75,7 +85,7 @@ def done_callback(outcome: Outcome) -> None: if op == "run": obj() elif op == "unwrap": - return obj.unwrap() + return obj.unwrap() # type: ignore[no-any-return] else: # pragma: no cover assert False finally: @@ -86,13 +96,13 @@ def done_callback(outcome: Outcome) -> None: def test_guest_trivial() -> None: - async def trio_return(in_host) -> str: + async def trio_return(in_host: InHost) -> str: await trio.sleep(0) return "ok" assert trivial_guest_run(trio_return) == "ok" - async def trio_fail(in_host): + async def trio_fail(in_host: InHost) -> NoReturn: raise KeyError("whoopsiedaisy") with pytest.raises(KeyError, match="whoopsiedaisy"): @@ -100,7 +110,7 @@ async def trio_fail(in_host): def test_guest_can_do_io() -> None: - async def trio_main(in_host) -> None: + async def trio_main(in_host: InHost) -> None: record = [] a, b = trio.socket.socketpair() with a, b: @@ -123,7 +133,7 @@ def test_guest_is_initialized_when_start_returns() -> None: trio_token = None record = [] - async def trio_main(in_host) -> str: + async def trio_main(in_host: InHost) -> str: record.append("main task ran") await trio.sleep(0) assert trio.lowlevel.current_trio_token() is trio_token @@ -151,7 +161,7 @@ async def early_task() -> None: with pytest.raises(trio.TrioInternalError): class BadClock: - def start_clock(self): + def start_clock(self) -> NoReturn: raise ValueError("whoops") def after_start_never_runs() -> None: # pragma: no cover @@ -163,7 +173,7 @@ def after_start_never_runs() -> None: # pragma: no cover def test_host_can_directly_wake_trio_task() -> None: - async def trio_main(in_host) -> str: + async def trio_main(in_host: InHost) -> str: ev = trio.Event() in_host(ev.set) await ev.wait() @@ -173,10 +183,10 @@ async def trio_main(in_host) -> str: def test_host_altering_deadlines_wakes_trio_up() -> None: - def set_deadline(cscope, new_deadline) -> None: + def set_deadline(cscope: trio.CancelScope, new_deadline: float) -> None: cscope.deadline = new_deadline - async def trio_main(in_host) -> str: + async def trio_main(in_host: InHost) -> str: with trio.CancelScope() as cscope: in_host(lambda: set_deadline(cscope, -inf)) await trio.sleep_forever() @@ -198,7 +208,7 @@ async def trio_main(in_host) -> str: def test_guest_mode_sniffio_integration() -> None: from sniffio import current_async_library, thread_local as sniffio_library - async def trio_main(in_host) -> str: + async def trio_main(in_host: InHost) -> str: async def synchronize() -> None: """Wait for all in_host() calls issued so far to complete.""" evt = trio.Event() @@ -227,7 +237,7 @@ async def synchronize() -> None: def test_warn_set_wakeup_fd_overwrite() -> None: assert signal.set_wakeup_fd(-1) == -1 - async def trio_main(in_host) -> str: + async def trio_main(in_host: InHost) -> str: return "ok" a, b = socket.socketpair() @@ -269,7 +279,7 @@ async def trio_main(in_host) -> str: signal.set_wakeup_fd(a.fileno()) try: - async def trio_check_wakeup_fd_unaltered(in_host) -> str: + async def trio_check_wakeup_fd_unaltered(in_host: InHost) -> str: fd = signal.set_wakeup_fd(-1) assert fd == a.fileno() signal.set_wakeup_fd(fd) @@ -295,12 +305,12 @@ def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked() -> None: # events is Truth-y # ...and confirm that in this case, wait_all_tasks_blocked does not get # triggered. - def set_deadline(cscope, new_deadline) -> None: + def set_deadline(cscope: trio.CancelScope, new_deadline: float) -> None: print(f"setting deadline {new_deadline}") cscope.deadline = new_deadline - async def trio_main(in_host) -> str: - async def sit_in_wait_all_tasks_blocked(watb_cscope) -> None: + async def trio_main(in_host: InHost) -> str: + async def sit_in_wait_all_tasks_blocked(watb_cscope: trio.CancelScope) -> None: with watb_cscope: # Overall point of this test is that this # wait_all_tasks_blocked should *not* return normally, but @@ -309,7 +319,7 @@ async def sit_in_wait_all_tasks_blocked(watb_cscope) -> None: assert False # pragma: no cover assert watb_cscope.cancelled_caught - async def get_woken_by_host_deadline(watb_cscope) -> None: + async def get_woken_by_host_deadline(watb_cscope: trio.CancelScope) -> None: with trio.CancelScope() as cscope: print("scheduling stuff to happen") @@ -367,7 +377,7 @@ def test_guest_warns_if_abandoned() -> None: # put it into a function, so that we're sure all the local state, # traceback frames, etc. are garbage once it returns. def do_abandoned_guest_run() -> None: - async def abandoned_main(in_host) -> None: + async def abandoned_main(in_host: InHost) -> None: in_host(lambda: 1 / 0) while True: await trio.sleep(0) @@ -402,13 +412,18 @@ async def abandoned_main(in_host) -> None: trio.current_time() -def aiotrio_run(trio_fn, *, pass_not_threadsafe: bool = True, **start_guest_run_kwargs): +def aiotrio_run( + trio_fn: Callable[..., Awaitable[T]], + *, + pass_not_threadsafe: bool = True, + **start_guest_run_kwargs: Any, +) -> T: loop = asyncio.new_event_loop() - async def aio_main(): + async def aio_main() -> T: trio_done_fut = loop.create_future() - def trio_done_callback(main_outcome) -> None: + def trio_done_callback(main_outcome: Outcome) -> None: print(f"trio_fn finished: {main_outcome!r}") trio_done_fut.set_result(main_outcome) @@ -422,7 +437,7 @@ def trio_done_callback(main_outcome) -> None: **start_guest_run_kwargs, ) - return (await trio_done_fut).unwrap() + return (await trio_done_fut).unwrap() # type: ignore[no-any-return] try: return loop.run_until_complete(aio_main()) @@ -454,7 +469,9 @@ async def trio_main() -> str: raise AssertionError("should never be reached") - async def aio_pingpong(from_trio, to_trio): + async def aio_pingpong( + from_trio: asyncio.Queue[int], to_trio: MemorySendChannel[int] + ) -> None: print("aio_pingpong!") try: @@ -491,10 +508,12 @@ async def aio_pingpong(from_trio, to_trio): ) -def test_guest_mode_internal_errors(monkeypatch: MonkeyPatch, recwarn) -> None: +def test_guest_mode_internal_errors( + monkeypatch: MonkeyPatch, recwarn: WarningsRecorder +) -> None: with monkeypatch.context() as m: - async def crash_in_run_loop(in_host) -> None: + async def crash_in_run_loop(in_host: InHost) -> None: m.setattr("trio._core._run.GLOBAL_RUN_CONTEXT.runner.runq", "HI") await trio.sleep(1) @@ -503,7 +522,7 @@ async def crash_in_run_loop(in_host) -> None: with monkeypatch.context() as m: - async def crash_in_io(in_host) -> None: + async def crash_in_io(in_host: InHost) -> None: m.setattr("trio._core._run.TheIOManager.get_events", None) await trio.sleep(0) @@ -512,11 +531,11 @@ async def crash_in_io(in_host) -> None: with monkeypatch.context() as m: - async def crash_in_worker_thread_io(in_host) -> None: + async def crash_in_worker_thread_io(in_host: InHost) -> None: t = threading.current_thread() old_get_events = trio._core._run.TheIOManager.get_events - def bad_get_events(*args): + def bad_get_events(*args: Any) -> object: if threading.current_thread() is not t: raise ValueError("oh no!") else: @@ -536,7 +555,7 @@ def test_guest_mode_ki() -> None: assert signal.getsignal(signal.SIGINT) is signal.default_int_handler # Check SIGINT in Trio func and in host func - async def trio_main(in_host) -> None: + async def trio_main(in_host: InHost) -> None: with pytest.raises(KeyboardInterrupt): signal_raise(signal.SIGINT) @@ -553,7 +572,7 @@ async def trio_main(in_host) -> None: # Also check chaining in the case where KI is injected after main exits final_exc = KeyError("whoa") - async def trio_main_raising(in_host): + async def trio_main_raising(in_host: InHost) -> NoReturn: in_host(partial(signal_raise, signal.SIGINT)) raise final_exc @@ -573,7 +592,7 @@ def test_guest_mode_autojump_clock_threshold_changing() -> None: DURATION = 120 - async def trio_main(in_host) -> None: + async def trio_main(in_host: InHost) -> None: assert trio.current_time() == 0 in_host(lambda: setattr(clock, "autojump_threshold", 0)) await trio.sleep(DURATION) @@ -594,7 +613,7 @@ def test_guest_mode_asyncgens() -> None: record = set() - async def agen(label: str): + async def agen(label: str): # TODO: some asyncgenerator thing assert sniffio.current_async_library() == label try: yield 1 @@ -623,6 +642,9 @@ async def trio_main() -> None: # Ensure we don't pollute the thread-level context if run under # an asyncio without contextvars support (3.6) context = contextvars.copy_context() - context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True) + if TYPE_CHECKING: + aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True) + # this type error is a bug in typeshed or mypy, as it's equivalent to the above line + context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True) # type: ignore[arg-type] assert record == {("asyncio", "asyncio"), ("trio", "trio")} diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py index 0241917b91..2b8d913948 100644 --- a/trio/_core/_tests/test_thread_cache.py +++ b/trio/_core/_tests/test_thread_cache.py @@ -4,7 +4,7 @@ import time from contextlib import contextmanager from queue import Queue -from typing import NoReturn +from typing import Iterator, NoReturn import pytest from outcome import Outcome @@ -21,7 +21,7 @@ def test_thread_cache_basics() -> None: def fn() -> NoReturn: raise RuntimeError("hi") - def deliver(outcome) -> None: + def deliver(outcome: Outcome) -> None: q.put(outcome) start_thread_soon(fn, deliver) @@ -43,7 +43,7 @@ def __del__(self) -> None: q = Queue[Outcome]() - def deliver(outcome) -> None: + def deliver(outcome: Outcome) -> None: q.put(outcome) start_thread_soon(del_me(), deliver) @@ -74,7 +74,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None: seen_threads = set() done = threading.Event() - def deliver(n, _) -> None: + def deliver(n: int, _: object) -> None: print(n) seen_threads.add(threading.current_thread()) if n == 0: @@ -106,7 +106,7 @@ def test_idle_threads_exit(monkeypatch: MonkeyPatch) -> None: @contextmanager -def _join_started_threads(): +def _join_started_threads() -> Iterator[None]: before = frozenset(threading.enumerate()) try: yield @@ -140,7 +140,7 @@ def __init__(self) -> None: self._lock = threading.Lock() self._counter = 3 - def acquire(self, timeout=-1) -> bool: + def acquire(self, timeout: int = -1) -> bool: got_it = self._lock.acquire(timeout=timeout) if timeout == -1: return True @@ -170,13 +170,13 @@ def release(self) -> None: tc.start_thread_soon(lambda: None, lambda _: None) -def test_raise_in_deliver(capfd) -> None: +def test_raise_in_deliver(capfd: pytest.CaptureFixture[str]) -> None: seen_threads = set() def track_threads() -> None: seen_threads.add(threading.current_thread()) - def deliver(_): + def deliver(_: object) -> NoReturn: done.set() raise RuntimeError("don't do this") diff --git a/trio/_ssl.py b/trio/_ssl.py index 21c4deabb3..073d7e3812 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -4,7 +4,7 @@ import ssl as _stdlib_ssl from collections.abc import Awaitable, Callable from enum import Enum as _Enum -from typing import Any, ClassVar, Final as TFinal, TypeVar +from typing import Any, ClassVar, Final as TFinal, Generic, TypeVar import trio @@ -239,9 +239,12 @@ def done(self) -> bool: _State = _Enum("_State", ["OK", "BROKEN", "CLOSED"]) +# TODO: variance +T_Stream = TypeVar("T_Stream", bound=Stream) + @final -class SSLStream(Stream): +class SSLStream(Stream, Generic[T_Stream]): r"""Encrypted communication using SSL/TLS. :class:`SSLStream` wraps an arbitrary :class:`~trio.abc.Stream`, and @@ -339,14 +342,14 @@ class SSLStream(Stream): # SSLListener.__init__, and maybe the open_ssl_over_tcp_* helpers. def __init__( self, - transport_stream: Stream, + transport_stream: T_Stream, ssl_context: _stdlib_ssl.SSLContext, *, server_hostname: str | bytes | None = None, server_side: bool = False, https_compatible: bool = False, ) -> None: - self.transport_stream: Stream = transport_stream + self.transport_stream: T_Stream = transport_stream self._state = _State.OK self._https_compatible = https_compatible self._outgoing = _stdlib_ssl.MemoryBIO() diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py index 7855aede6f..bcfc450362 100644 --- a/trio/_tests/test_exports.py +++ b/trio/_tests/test_exports.py @@ -12,7 +12,7 @@ from collections.abc import Iterator from pathlib import Path from types import ModuleType -from typing import Protocol +from typing import Iterable, Protocol import attrs import pytest @@ -126,10 +126,10 @@ def iter_modules( # https://github.com/pypa/setuptools/issues/3274 "ignore:module 'sre_constants' is deprecated:DeprecationWarning", ) -def test_static_tool_sees_all_symbols(tool, modname: str, tmp_path) -> None: +def test_static_tool_sees_all_symbols(tool: str, modname: str, tmp_path: Path) -> None: module = importlib.import_module(modname) - def no_underscores(symbols): + def no_underscores(symbols: Iterable[str]) -> set[str]: return {symbol for symbol in symbols if not symbol.startswith("_")} runtime_names = no_underscores(dir(module)) @@ -278,7 +278,7 @@ def test_static_tool_sees_class_members( module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)] # ignore hidden, but not dunder, symbols - def no_hidden(symbols): + def no_hidden(symbols: Iterable[str]) -> set[str]: return { symbol for symbol in symbols @@ -316,7 +316,7 @@ def no_hidden(symbols): # skip a bunch of file-system activity (probably can un-memoize?) @functools.lru_cache - def lookup_symbol(symbol): + def lookup_symbol(symbol: str) -> dict[str, str]: topname, *modname, name = symbol.split(".") version = next(cache.glob("3.*/")) mod_cache = version / topname @@ -333,7 +333,7 @@ def lookup_symbol(symbol): mod_cache = mod_cache / (modname[-1] + ".data.json") with mod_cache.open() as f: - return json.loads(f.read())["names"][name] + return json.loads(f.read())["names"][name] # type: ignore[no-any-return] errors: dict[str, object] = {} for class_name, class_ in module.__dict__.items(): @@ -441,6 +441,16 @@ def lookup_symbol(symbol): extra = {e for e in extra if not e.endswith("AttrsAttributes__")} assert len(extra) == before - 1 + import enum + + # mypy does not see these attributes in Enum subclasses + if ( + tool == "mypy" + and enum.Enum in class_.__mro__ + and sys.version_info >= (3, 11) + ): + extra.difference_update({"__copy__", "__deepcopy__"}) + # TODO: this *should* be visible via `dir`!! if tool == "mypy" and class_ == trio.Nursery: extra.remove("cancel_scope") diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py index 2e333068ec..4a267222a0 100644 --- a/trio/_tests/test_highlevel_serve_listeners.py +++ b/trio/_tests/test_highlevel_serve_listeners.py @@ -1,7 +1,6 @@ from __future__ import annotations import errno -import sys from functools import partial from typing import Awaitable, Callable, NoReturn @@ -20,9 +19,6 @@ wait_all_tasks_blocked, ) -if sys.version_info < (3, 11): - from exceptiongroup import ExceptionGroup - # types are somewhat tentative - I just bruteforced them until I got something that didn't # give errors TypeThing = StapledStream[MemorySendStream, MemoryReceiveStream] @@ -138,7 +134,8 @@ async def raise_EMFILE() -> NoReturn: assert len(caplog.records) == 10 for record in caplog.records: assert "retrying" in record.msg - assert isinstance(record.exc_info, ExceptionGroup) + assert record.exc_info is not None + assert isinstance(record.exc_info[1], OSError) assert record.exc_info[1].errno == errno.EMFILE diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py index afb5f30a6b..5a06279204 100644 --- a/trio/_tests/test_highlevel_ssl_helpers.py +++ b/trio/_tests/test_highlevel_ssl_helpers.py @@ -1,12 +1,16 @@ from __future__ import annotations from functools import partial +from socket import AddressFamily, SocketKind +from ssl import SSLContext +from typing import Any, NoReturn import attr import pytest import trio import trio.testing +from trio.abc import Stream from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM from .._highlevel_socket import SocketListener @@ -21,7 +25,7 @@ from .test_ssl import SERVER_CTX, client_ctx # noqa: F401 -async def echo_handler(stream) -> None: +async def echo_handler(stream: Stream) -> None: async with stream: try: while True: @@ -37,19 +41,35 @@ async def echo_handler(stream) -> None: # you ask for. @attr.s class FakeHostnameResolver(trio.abc.HostnameResolver): - sockaddr = attr.ib() - - async def getaddrinfo(self, *args): + sockaddr: tuple[str, int] | tuple[str, int, int, int] = attr.ib() + + async def getaddrinfo( + self, + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)] - async def getnameinfo(self, *args): # pragma: no cover + async def getnameinfo(self, *args: Any) -> NoReturn: # pragma: no cover raise NotImplementedError # This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners... # using noqa because linters don't understand how pytest fixtures work. async def test_open_ssl_over_tcp_stream_and_everything_else( - client_ctx, # noqa: F811 # linters doesn't understand fixture + client_ctx: SSLContext, # noqa: F811 # linters doesn't understand fixture ) -> None: async with trio.open_nursery() as nursery: # TODO: the types are *very* funky here, this seems like an error in some signature diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index 53943cabce..5dece46e54 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -7,11 +7,18 @@ import threading from contextlib import asynccontextmanager, contextmanager from functools import partial +from ssl import SSLContext +from typing import Any, AsyncIterator, Iterator, NoReturn import pytest +from typing_extensions import TypeAlias +from trio import StapledStream from trio._core import MockClock +from trio._ssl import T_Stream from trio._tests.pytest_plugin import skip_if_optional_else_raise +from trio.abc import ReceiveStream, SendStream +from trio.testing import MemoryReceiveStream, MemorySendStream try: import trustme @@ -24,6 +31,7 @@ from .. import _core, socket as tsocket from .._abc import Stream from .._core import BrokenResourceError, ClosedResourceError +from .._core._run import CancelScope from .._core._tests.tutil import slow from .._highlevel_generic import aclose_forcefully from .._highlevel_open_tcp_stream import open_tcp_stream @@ -73,7 +81,7 @@ # downgrade on the server side. "tls12" means we refuse to negotiate TLS # 1.3, so we'll almost certainly use TLS 1.2. @pytest.fixture(scope="module", params=["tls13", "tls12"]) -def client_ctx(request): +def client_ctx(request: pytest.FixtureRequest) -> ssl.SSLContext: ctx = ssl.create_default_context() if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): @@ -90,7 +98,9 @@ def client_ctx(request): # The blocking socket server. -def ssl_echo_serve_sync(sock, *, expect_fail: bool = False): +def ssl_echo_serve_sync( + sock: stdlib_socket.socket, *, expect_fail: bool = False +) -> None: try: wrapped = SERVER_CTX.wrap_socket( sock, server_side=True, suppress_ragged_eofs=False @@ -142,8 +152,8 @@ def ssl_echo_serve_sync(sock, *, expect_fail: bool = False): # Fixture that gives a raw socket connected to a trio-test-1 echo server # (running in a thread). Useful for testing making connections with different # SSLContexts. -@asynccontextmanager -async def ssl_echo_server_raw(**kwargs): +@asynccontextmanager # type: ignore[misc] # decorated contains Any +async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]: a, b = stdlib_socket.socketpair() async with trio.open_nursery() as nursery: # Exiting the 'with a, b' context manager closes the sockets, which @@ -159,8 +169,10 @@ async def ssl_echo_server_raw(**kwargs): # Fixture that gives a properly set up SSLStream connected to a trio-test-1 # echo server (running in a thread) -@asynccontextmanager -async def ssl_echo_server(client_ctx, **kwargs): +@asynccontextmanager # type: ignore[misc] # decorated contains Any +async def ssl_echo_server( + client_ctx: SSLContext, **kwargs: Any +) -> AsyncIterator[SSLStream[Stream]]: async with ssl_echo_server_raw(**kwargs) as sock: yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org") @@ -171,7 +183,7 @@ async def ssl_echo_server(client_ctx, **kwargs): # jakkdl: it seems to implement all the abstract methods (now), so I made it inherit # from Stream for the sake of typechecking. class PyOpenSSLEchoStream(Stream): - def __init__(self, sleeper=None) -> None: + def __init__(self, sleeper: None = None) -> None: ctx = SSL.Context(SSL.SSLv23_METHOD) # TLS 1.3 removes renegotiation support. Which is great for them, but # we still have to support versions before that, and that means we @@ -221,7 +233,7 @@ def __init__(self, sleeper=None) -> None: if sleeper is None: - async def no_op_sleeper(_) -> None: + async def no_op_sleeper(_: object) -> None: return self.sleeper = no_op_sleeper @@ -231,7 +243,7 @@ async def no_op_sleeper(_) -> None: async def aclose(self) -> None: self._conn.bio_shutdown() - def renegotiate_pending(self): + def renegotiate_pending(self) -> bool: return self._conn.renegotiate_pending() def renegotiate(self) -> None: @@ -245,7 +257,7 @@ async def wait_send_all_might_not_block(self) -> None: await _core.checkpoint() await self.sleeper("wait_send_all_might_not_block") - async def send_all(self, data) -> None: + async def send_all(self, data: bytes) -> None: print(" --> transport_stream.send_all") with self._send_all_conflict_detector: await _core.checkpoint() @@ -268,7 +280,7 @@ async def send_all(self, data) -> None: await self.sleeper("send_all") print(" <-- transport_stream.send_all finished") - async def receive_some(self, nbytes=None): + async def receive_some(self, nbytes: int | None = None) -> bytes: print(" --> transport_stream.receive_some") if nbytes is None: nbytes = 65536 # arbitrary @@ -360,20 +372,22 @@ async def test_PyOpenSSLEchoStream_gives_resource_busy_errors() -> None: assert "simultaneous" in str(excinfo.value) -@contextmanager -def virtual_ssl_echo_server(client_ctx, **kwargs): +@contextmanager # type: ignore[misc] # decorated contains Any +def virtual_ssl_echo_server( + client_ctx: SSLContext, **kwargs: Any +) -> Iterator[SSLStream[PyOpenSSLEchoStream]]: fakesock = PyOpenSSLEchoStream(**kwargs) yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org") def ssl_wrap_pair( - client_ctx, - client_transport, - server_transport, + client_ctx: SSLContext, + client_transport: T_Stream, + server_transport: T_Stream, *, - client_kwargs={}, - server_kwargs={}, -): + client_kwargs: dict[str, Any] = {}, + server_kwargs: dict[str, Any] = {}, +) -> tuple[SSLStream[T_Stream], SSLStream[T_Stream]]: client_ssl = SSLStream( client_transport, client_ctx, @@ -386,12 +400,22 @@ def ssl_wrap_pair( return client_ssl, server_ssl -def ssl_memory_stream_pair(client_ctx, **kwargs): +MemoryStapledStream: TypeAlias = StapledStream[MemorySendStream, MemoryReceiveStream] + + +def ssl_memory_stream_pair( + client_ctx: SSLContext, **kwargs: Any +) -> tuple[SSLStream[MemoryStapledStream], SSLStream[MemoryStapledStream],]: client_transport, server_transport = memory_stream_pair() return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) -def ssl_lockstep_stream_pair(client_ctx, **kwargs): +MyStapledStream: TypeAlias = StapledStream[SendStream, ReceiveStream] + + +def ssl_lockstep_stream_pair( + client_ctx: SSLContext, **kwargs: Any +) -> tuple[SSLStream[MyStapledStream], SSLStream[MyStapledStream],]: client_transport, server_transport = lockstep_stream_pair() return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) @@ -399,7 +423,7 @@ def ssl_lockstep_stream_pair(client_ctx, **kwargs): # Simple smoke test for handshake/send/receive/shutdown talking to a # synchronous server, plus make sure that we do the bare minimum of # certificate checking (even though this is really Python's responsibility) -async def test_ssl_client_basics(client_ctx) -> None: +async def test_ssl_client_basics(client_ctx: SSLContext) -> None: # Everything OK async with ssl_echo_server(client_ctx) as s: assert not s.server_side @@ -425,7 +449,7 @@ async def test_ssl_client_basics(client_ctx) -> None: assert isinstance(excinfo.value.__cause__, ssl.CertificateError) -async def test_ssl_server_basics(client_ctx) -> None: +async def test_ssl_server_basics(client_ctx: SSLContext) -> None: a, b = stdlib_socket.socketpair() with a, b: server_sock = tsocket.from_stdlib_socket(b) @@ -455,7 +479,7 @@ def client() -> None: t.join() -async def test_attributes(client_ctx) -> None: +async def test_attributes(client_ctx: SSLContext) -> None: async with ssl_echo_server_raw(expect_fail=True) as sock: good_ctx = client_ctx bad_ctx = ssl.create_default_context() @@ -524,7 +548,7 @@ async def test_attributes(client_ctx) -> None: # I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it... -async def test_full_duplex_basics(client_ctx) -> None: +async def test_full_duplex_basics(client_ctx: SSLContext) -> None: CHUNKS = 30 CHUNK_SIZE = 32768 EXPECTED = CHUNKS * CHUNK_SIZE @@ -532,7 +556,7 @@ async def test_full_duplex_basics(client_ctx) -> None: sent = bytearray() received = bytearray() - async def sender(s) -> None: + async def sender(s: Stream) -> None: nonlocal sent for i in range(CHUNKS): print(i) @@ -540,7 +564,7 @@ async def sender(s) -> None: sent += chunk await s.send_all(chunk) - async def receiver(s) -> None: + async def receiver(s: Stream) -> None: nonlocal received while len(received) < EXPECTED: chunk = await s.receive_some(CHUNK_SIZE // 2) @@ -561,10 +585,9 @@ async def receiver(s) -> None: assert sent == received -async def test_renegotiation_simple(client_ctx) -> None: +async def test_renegotiation_simple(client_ctx: SSLContext) -> None: with virtual_ssl_echo_server(client_ctx) as s: await s.do_handshake() - s.transport_stream.renegotiate() await s.send_all(b"a") assert await s.receive_some(1) == b"a" @@ -580,7 +603,9 @@ async def test_renegotiation_simple(client_ctx) -> None: @slow -async def test_renegotiation_randomized(mock_clock: MockClock, client_ctx) -> None: +async def test_renegotiation_randomized( + mock_clock: MockClock, client_ctx: SSLContext +) -> None: # The only blocking things in this function are our random sleeps, so 0 is # a good threshold. mock_clock.autojump_threshold = 0 @@ -589,7 +614,7 @@ async def test_renegotiation_randomized(mock_clock: MockClock, client_ctx) -> No r = random.Random(0) - async def sleeper(_) -> None: + async def sleeper(_: object) -> None: await trio.sleep(r.uniform(0, 10)) async def clear() -> None: @@ -600,13 +625,13 @@ async def clear() -> None: await expect(b"-") print("-- clear --") - async def send(byte) -> None: + async def send(byte: bytes) -> None: await s.transport_stream.sleeper("outer send") print("calling SSLStream.send_all", byte) with assert_checkpoints(): await s.send_all(byte) - async def expect(expected) -> None: + async def expect(expected: bytes) -> None: await s.transport_stream.sleeper("expect") print("calling SSLStream.receive_some, expecting", expected) assert len(expected) == 1 @@ -652,7 +677,7 @@ async def expect(expected) -> None: # and wait_send_all_might_not_block comes in. # Our receive_some() call will get stuck when it hits send_all - async def sleeper_with_slow_send_all(method) -> None: + async def sleeper_with_slow_send_all(method: str) -> None: if method == "send_all": await trio.sleep(100000) @@ -676,7 +701,7 @@ async def sleep_then_wait_writable() -> None: # 2) Same, but now wait_send_all_might_not_block is stuck when # receive_some tries to send. - async def sleeper_with_slow_wait_writable_and_expect(method) -> None: + async def sleeper_with_slow_wait_writable_and_expect(method: str) -> None: if method == "wait_send_all_might_not_block": await trio.sleep(100000) elif method == "expect": @@ -696,7 +721,7 @@ async def sleeper_with_slow_wait_writable_and_expect(method) -> None: await s.aclose() -async def test_resource_busy_errors(client_ctx) -> None: +async def test_resource_busy_errors(client_ctx: SSLContext) -> None: async def do_send_all() -> None: with assert_checkpoints(): await s.send_all(b"x") @@ -765,7 +790,7 @@ async def send_all(self, data: bytes | bytearray | memoryview) -> None: os.name == "nt" and sys.version_info >= (3, 10), reason="frequently fails on Windows + Python 3.10", ) -async def test_checkpoints(client_ctx) -> None: +async def test_checkpoints(client_ctx: SSLContext) -> None: async with ssl_echo_server(client_ctx) as s: with assert_checkpoints(): await s.do_handshake() @@ -794,7 +819,7 @@ async def test_checkpoints(client_ctx) -> None: await s.aclose() -async def test_send_all_empty_string(client_ctx) -> None: +async def test_send_all_empty_string(client_ctx: SSLContext) -> None: async with ssl_echo_server(client_ctx) as s: await s.do_handshake() @@ -811,15 +836,23 @@ async def test_send_all_empty_string(client_ctx) -> None: @pytest.mark.parametrize("https_compatible", [False, True]) -async def test_SSLStream_generic(client_ctx, https_compatible) -> None: - async def stream_maker(): +async def test_SSLStream_generic( + client_ctx: SSLContext, https_compatible: bool +) -> None: + async def stream_maker() -> tuple[ + SSLStream[MemoryStapledStream], + SSLStream[MemoryStapledStream], + ]: return ssl_memory_stream_pair( client_ctx, client_kwargs={"https_compatible": https_compatible}, server_kwargs={"https_compatible": https_compatible}, ) - async def clogged_stream_maker(): + async def clogged_stream_maker() -> tuple[ + SSLStream[MyStapledStream], + SSLStream[MyStapledStream], + ]: client, server = ssl_lockstep_stream_pair(client_ctx) # If we don't do handshakes up front, then we run into a problem in # the following situation: @@ -835,7 +868,7 @@ async def clogged_stream_maker(): await check_two_way_stream(stream_maker, clogged_stream_maker) -async def test_unwrap(client_ctx) -> None: +async def test_unwrap(client_ctx: SSLContext) -> None: client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) client_transport = client_ssl.transport_stream server_transport = server_ssl.transport_stream @@ -889,7 +922,7 @@ async def server() -> None: nursery.start_soon(server) -async def test_closing_nice_case(client_ctx) -> None: +async def test_closing_nice_case(client_ctx: SSLContext) -> None: # the nice case: graceful closes all around client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) @@ -951,14 +984,14 @@ async def expect_eof_server() -> None: nursery.start_soon(expect_eof_server) -async def test_send_all_fails_in_the_middle(client_ctx) -> None: +async def test_send_all_fails_in_the_middle(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with _core.open_nursery() as nursery: nursery.start_soon(client.do_handshake) nursery.start_soon(server.do_handshake) - async def bad_hook(): + async def bad_hook() -> NoReturn: raise KeyError client.transport_stream.send_stream.send_all_hook = bad_hook @@ -982,7 +1015,7 @@ def close_hook() -> None: assert closed == 2 -async def test_ssl_over_ssl(client_ctx) -> None: +async def test_ssl_over_ssl(client_ctx: SSLContext) -> None: client_0, server_0 = memory_stream_pair() client_1 = SSLStream( @@ -1008,7 +1041,7 @@ async def server() -> None: nursery.start_soon(server) -async def test_ssl_bad_shutdown(client_ctx) -> None: +async def test_ssl_bad_shutdown(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with _core.open_nursery() as nursery: @@ -1025,7 +1058,7 @@ async def test_ssl_bad_shutdown(client_ctx) -> None: await server.aclose() -async def test_ssl_bad_shutdown_but_its_ok(client_ctx) -> None: +async def test_ssl_bad_shutdown_but_its_ok(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, @@ -1064,7 +1097,7 @@ async def test_ssl_handshake_failure_during_aclose() -> None: await s.aclose() -async def test_ssl_only_closes_stream_once(client_ctx) -> None: +async def test_ssl_only_closes_stream_once(client_ctx: SSLContext) -> None: # We used to have a bug where if transport_stream.aclose() raised an # error, we would call it again. This checks that that's fixed. client, server = ssl_memory_stream_pair(client_ctx) @@ -1076,8 +1109,9 @@ async def test_ssl_only_closes_stream_once(client_ctx) -> None: client_orig_close_hook = client.transport_stream.send_stream.close_hook transport_close_count = 0 - def close_hook(): + def close_hook() -> NoReturn: nonlocal transport_close_count + assert client_orig_close_hook is not None client_orig_close_hook() transport_close_count += 1 raise KeyError @@ -1089,7 +1123,7 @@ def close_hook(): assert transport_close_count == 1 -async def test_ssl_https_compatibility_disagreement(client_ctx) -> None: +async def test_ssl_https_compatibility_disagreement(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": False}, @@ -1113,7 +1147,7 @@ async def receive_and_expect_error() -> None: nursery.start_soon(receive_and_expect_error) -async def test_https_mode_eof_before_handshake(client_ctx) -> None: +async def test_https_mode_eof_before_handshake(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, @@ -1128,10 +1162,10 @@ async def server_expect_clean_eof() -> None: nursery.start_soon(server_expect_clean_eof) -async def test_send_error_during_handshake(client_ctx) -> None: +async def test_send_error_during_handshake(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair(client_ctx) - async def bad_hook(): + async def bad_hook() -> NoReturn: raise KeyError client.transport_stream.send_stream.send_all_hook = bad_hook @@ -1145,15 +1179,15 @@ async def bad_hook(): await client.do_handshake() -async def test_receive_error_during_handshake(client_ctx) -> None: +async def test_receive_error_during_handshake(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair(client_ctx) - async def bad_hook(): + async def bad_hook() -> NoReturn: raise KeyError client.transport_stream.receive_stream.receive_some_hook = bad_hook - async def client_side(cancel_scope) -> None: + async def client_side(cancel_scope: CancelScope) -> None: with pytest.raises(KeyError): with assert_checkpoints(): await client.do_handshake() @@ -1168,7 +1202,7 @@ async def client_side(cancel_scope) -> None: await client.do_handshake() -async def test_selected_alpn_protocol_before_handshake(client_ctx) -> None: +async def test_selected_alpn_protocol_before_handshake(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1178,7 +1212,7 @@ async def test_selected_alpn_protocol_before_handshake(client_ctx) -> None: server.selected_alpn_protocol() -async def test_selected_alpn_protocol_when_not_set(client_ctx) -> None: +async def test_selected_alpn_protocol_when_not_set(client_ctx: SSLContext) -> None: # ALPN protocol still returns None when it's not set, # instead of raising an exception client, server = ssl_memory_stream_pair(client_ctx) @@ -1193,7 +1227,7 @@ async def test_selected_alpn_protocol_when_not_set(client_ctx) -> None: assert client.selected_alpn_protocol() == server.selected_alpn_protocol() -async def test_selected_npn_protocol_before_handshake(client_ctx) -> None: +async def test_selected_npn_protocol_before_handshake(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1207,7 +1241,7 @@ async def test_selected_npn_protocol_before_handshake(client_ctx) -> None: r"ignore: ssl module. NPN is deprecated, use ALPN instead:UserWarning", r"ignore:ssl NPN is deprecated, use ALPN instead:DeprecationWarning", ) -async def test_selected_npn_protocol_when_not_set(client_ctx) -> None: +async def test_selected_npn_protocol_when_not_set(client_ctx: SSLContext) -> None: # NPN protocol still returns None when it's not set, # instead of raising an exception client, server = ssl_memory_stream_pair(client_ctx) @@ -1222,7 +1256,7 @@ async def test_selected_npn_protocol_when_not_set(client_ctx) -> None: assert client.selected_npn_protocol() == server.selected_npn_protocol() -async def test_get_channel_binding_before_handshake(client_ctx) -> None: +async def test_get_channel_binding_before_handshake(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1232,7 +1266,7 @@ async def test_get_channel_binding_before_handshake(client_ctx) -> None: server.get_channel_binding() -async def test_get_channel_binding_after_handshake(client_ctx) -> None: +async def test_get_channel_binding_after_handshake(client_ctx: SSLContext) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with _core.open_nursery() as nursery: @@ -1245,7 +1279,7 @@ async def test_get_channel_binding_after_handshake(client_ctx) -> None: assert client.get_channel_binding() == server.get_channel_binding() -async def test_getpeercert(client_ctx) -> None: +async def test_getpeercert(client_ctx: SSLContext) -> None: # Make sure we're not affected by https://bugs.python.org/issue29334 client, server = ssl_memory_stream_pair(client_ctx) @@ -1258,8 +1292,10 @@ async def test_getpeercert(client_ctx) -> None: assert ("DNS", "trio-test-1.example.org") in client.getpeercert()["subjectAltName"] -async def test_SSLListener(client_ctx) -> None: - async def setup(**kwargs): +async def test_SSLListener(client_ctx: SSLContext) -> None: + async def setup( + **kwargs: Any, + ) -> tuple[tsocket.SocketType, SSLListener, SSLStream[SocketStream]]: listen_sock = tsocket.socket() await listen_sock.bind(("127.0.0.1", 0)) listen_sock.listen(1) diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index 3ebf2589bc..b0c93cdf7a 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, AsyncIterator import pytest -from pytest import MonkeyPatch +from pytest import MonkeyPatch, WarningsRecorder from .. import ( ClosedResourceError, @@ -153,7 +153,7 @@ async def test_multi_wait(background_process) -> None: # Test for deprecated 'async with process:' semantics -async def test_async_with_basics_deprecated(recwarn) -> None: +async def test_async_with_basics_deprecated(recwarn: WarningsRecorder) -> None: async with await open_process( CAT, stdin=subprocess.PIPE, stdout=subprocess.PIPE ) as proc: @@ -168,7 +168,7 @@ async def test_async_with_basics_deprecated(recwarn) -> None: # Test for deprecated 'async with process:' semantics -async def test_kill_when_context_cancelled(recwarn) -> None: +async def test_kill_when_context_cancelled(recwarn: WarningsRecorder) -> None: with move_on_after(100) as scope: async with await open_process(SLEEP(10)) as proc: assert proc.poll() is None diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py index 91c865e085..448ead15da 100644 --- a/trio/_tests/test_sync.py +++ b/trio/_tests/test_sync.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import weakref +from typing import Callable, Union import pytest +from typing_extensions import TypeAlias from .. import _core from .._sync import * @@ -204,7 +208,7 @@ async def test_Semaphore() -> None: record = [] - async def do_acquire(s) -> None: + async def do_acquire(s: Semaphore) -> None: record.append("started") await s.acquire() record.append("finished") @@ -240,7 +244,9 @@ async def test_Semaphore_bounded() -> None: @pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__) -async def test_Lock_and_StrictFIFOLock(lockcls) -> None: +async def test_Lock_and_StrictFIFOLock( + lockcls: type[Lock] | type[StrictFIFOLock], +) -> None: l = lockcls() # noqa assert not l.locked() @@ -349,7 +355,7 @@ async def test_Condition() -> None: finished_waiters = set() - async def waiter(i) -> None: + async def waiter(i: int) -> None: async with c: await c.wait() finished_waiters.add(i) @@ -489,17 +495,28 @@ def release(self) -> None: "lock_factory", lock_factories, ids=lock_factory_names ) +LockLike: TypeAlias = Union[ + CapacityLimiter, + Semaphore, + Lock, + StrictFIFOLock, + ChannelLock1, + ChannelLock2, + ChannelLock3, +] +LockFactory: TypeAlias = Callable[[], LockLike] + # Spawn a bunch of workers that take a lock and then yield; make sure that # only one worker is ever in the critical section at a time. @generic_lock_test -async def test_generic_lock_exclusion(lock_factory) -> None: +async def test_generic_lock_exclusion(lock_factory: LockFactory) -> None: LOOPS = 10 WORKERS = 5 in_critical_section = False acquires = 0 - async def worker(lock_like) -> None: + async def worker(lock_like: LockLike) -> None: nonlocal in_critical_section, acquires for _ in range(LOOPS): async with lock_like: @@ -522,12 +539,12 @@ async def worker(lock_like) -> None: # Several workers queue on the same lock; make sure they each get it, in # order. @generic_lock_test -async def test_generic_lock_fifo_fairness(lock_factory) -> None: +async def test_generic_lock_fifo_fairness(lock_factory: LockFactory) -> None: initial_order = [] record = [] LOOPS = 5 - async def loopy(name: str, lock_like) -> None: + async def loopy(name: str, lock_like: LockLike) -> None: # Record the order each task was initially scheduled in initial_order.append(name) for _ in range(LOOPS): @@ -546,7 +563,9 @@ async def loopy(name: str, lock_like) -> None: @generic_lock_test -async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory) -> None: +async def test_generic_lock_acquire_nowait_blocks_acquire( + lock_factory: LockFactory, +) -> None: lock_like = lock_factory() record = [] diff --git a/trio/_tests/test_testing.py b/trio/_tests/test_testing.py index 7461de058c..1c36567560 100644 --- a/trio/_tests/test_testing.py +++ b/trio/_tests/test_testing.py @@ -1,12 +1,17 @@ -# XX this should get broken up, like testing.py did +from __future__ import annotations +# XX this should get broken up, like testing.py did import tempfile import pytest +from pytest import WarningsRecorder + +from trio import Nursery +from trio.abc import ReceiveStream, SendStream from .. import _core, sleep, socket as tsocket from .._core._tests.tutil import can_bind_ipv6 -from .._highlevel_generic import aclose_forcefully +from .._highlevel_generic import StapledStream, aclose_forcefully from .._highlevel_socket import SocketListener from ..testing import * from ..testing._check_streams import _assert_raises @@ -104,7 +109,7 @@ async def wait_big_cushion() -> None: ################################################################ -async def test_assert_checkpoints(recwarn) -> None: +async def test_assert_checkpoints(recwarn: WarningsRecorder) -> None: with assert_checkpoints(): await _core.checkpoint() @@ -130,7 +135,7 @@ async def test_assert_checkpoints(recwarn) -> None: await _core.cancel_shielded_checkpoint() -async def test_assert_no_checkpoints(recwarn) -> None: +async def test_assert_no_checkpoints(recwarn: WarningsRecorder) -> None: with assert_no_checkpoints(): 1 + 1 @@ -163,11 +168,11 @@ async def test_assert_no_checkpoints(recwarn) -> None: async def test_Sequencer() -> None: record = [] - def t(val) -> None: + def t(val: object) -> None: print(val) record.append(val) - async def f1(seq) -> None: + async def f1(seq: Sequencer) -> None: async with seq(1): t(("f1", 1)) async with seq(3): @@ -175,7 +180,7 @@ async def f1(seq) -> None: async with seq(4): t(("f1", 4)) - async def f2(seq) -> None: + async def f2(seq: Sequencer) -> None: async with seq(0): t(("f2", 0)) async with seq(2): @@ -203,7 +208,7 @@ async def test_Sequencer_cancel() -> None: record = [] seq = Sequencer() - async def child(i) -> None: + async def child(i: int) -> None: with _core.CancelScope() as scope: if i == 1: scope.cancel() @@ -230,7 +235,7 @@ async def child(i) -> None: ################################################################ -async def test__assert_raises(): +async def test__assert_raises() -> None: with pytest.raises(AssertionError): with _assert_raises(RuntimeError): 1 + 1 @@ -273,11 +278,11 @@ async def test__UnboundeByteQueue() -> None: with assert_checkpoints(): assert await ubq.get() == b"efghi" - async def putter(data) -> None: + async def putter(data: bytes) -> None: await wait_all_tasks_blocked() ubq.put(data) - async def getter(expect) -> None: + async def getter(expect: bytes) -> None: with assert_checkpoints(): assert await ubq.get() == expect @@ -320,7 +325,7 @@ async def closer() -> None: async def test_MemorySendStream() -> None: mss = MemorySendStream() - async def do_send_all(data) -> None: + async def do_send_all(data: bytes) -> None: with assert_checkpoints(): await mss.send_all(data) @@ -410,7 +415,7 @@ def close_hook() -> None: async def test_MemoryReceiveStream() -> None: mrs = MemoryReceiveStream() - async def do_receive_some(max_bytes): + async def do_receive_some(max_bytes: int | None) -> bytes: with assert_checkpoints(): return await mrs.receive_some(max_bytes) @@ -521,7 +526,7 @@ async def test_memory_stream_one_way_pair() -> None: await s.send_all(b"123") assert await r.receive_some(10) == b"123" - async def receiver(expected) -> None: + async def receiver(expected: bytes) -> None: assert await r.receive_some(10) == expected # This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook @@ -549,7 +554,7 @@ async def receiver(expected) -> None: s.send_all_hook = None await s.send_all(b"456") - async def cancel_after_idle(nursery) -> None: + async def cancel_after_idle(nursery: Nursery) -> None: await wait_all_tasks_blocked() nursery.cancel_scope.cancel() @@ -591,31 +596,37 @@ async def receiver() -> None: async def test_memory_streams_with_generic_tests() -> None: - async def one_way_stream_maker(): + async def one_way_stream_maker() -> tuple[MemorySendStream, MemoryReceiveStream]: return memory_stream_one_way_pair() await check_one_way_stream(one_way_stream_maker, None) - async def half_closeable_stream_maker(): + async def half_closeable_stream_maker() -> tuple[ + StapledStream[MemorySendStream, MemoryReceiveStream], + StapledStream[MemorySendStream, MemoryReceiveStream], + ]: return memory_stream_pair() await check_half_closeable_stream(half_closeable_stream_maker, None) async def test_lockstep_streams_with_generic_tests() -> None: - async def one_way_stream_maker(): + async def one_way_stream_maker() -> tuple[SendStream, ReceiveStream]: return lockstep_stream_one_way_pair() await check_one_way_stream(one_way_stream_maker, one_way_stream_maker) - async def two_way_stream_maker(): + async def two_way_stream_maker() -> tuple[ + StapledStream[SendStream, ReceiveStream], + StapledStream[SendStream, ReceiveStream], + ]: return lockstep_stream_pair() await check_two_way_stream(two_way_stream_maker, two_way_stream_maker) async def test_open_stream_to_socket_listener() -> None: - async def check(listener) -> None: + async def check(listener: SocketListener) -> None: async with listener: client_stream = await open_stream_to_socket_listener(listener) async with client_stream: diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 0fe5d8dc48..80b870445d 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -360,7 +360,9 @@ async def child(q, cancellable): # Make sure that if trio.run exits, and then the thread finishes, then that's # handled gracefully. (Requires that the thread result machinery be prepared # for call_soon to raise RunFinishedError.) -def test_run_in_worker_thread_abandoned(capfd, monkeypatch: MonkeyPatch) -> None: +def test_run_in_worker_thread_abandoned( + capfd: pytest.CaptureFixture[str], monkeypatch: MonkeyPatch +) -> None: monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) q1 = stdlib_queue.Queue[None]() diff --git a/trio/_tests/test_util.py b/trio/_tests/test_util.py index f074a06096..01a7503e9d 100644 --- a/trio/_tests/test_util.py +++ b/trio/_tests/test_util.py @@ -1,6 +1,7 @@ import signal import sys import types +from typing import Any, TypeVar import pytest @@ -23,11 +24,13 @@ ) from ..testing import wait_all_tasks_blocked +T = TypeVar("T") + def test_signal_raise() -> None: record = [] - def handler(signum, _) -> None: + def handler(signum: int, _: object) -> None: record.append(signum) old = signal.signal(signal.SIGFPE, handler) @@ -114,9 +117,9 @@ async def f() -> None: # pragma: no cover import asyncio if sys.version_info < (3, 11): - - @asyncio.coroutine - def generator_based_coro(): # pragma: no cover + # not bothering to type this one + @asyncio.coroutine # type: ignore[misc] + def generator_based_coro() -> Any: # pragma: no cover yield from asyncio.sleep(1) with pytest.raises(TypeError) as excinfo: @@ -145,7 +148,7 @@ def generator_based_coro(): # pragma: no cover assert "appears to be synchronous" in str(excinfo.value) - async def async_gen(arg): # pragma: no cover + async def async_gen(_: object) -> Any: # pragma: no cover yield # does not give arg-type typing error @@ -160,7 +163,7 @@ async def async_gen(arg): # pragma: no cover def test_generic_function() -> None: @generic_function - def test_func(arg): + def test_func(arg: T) -> T: """Look, a docstring!""" return arg diff --git a/trio/_tests/test_wait_for_object.py b/trio/_tests/test_wait_for_object.py index 44790497ed..b41bcba3a5 100644 --- a/trio/_tests/test_wait_for_object.py +++ b/trio/_tests/test_wait_for_object.py @@ -153,7 +153,7 @@ async def test_WaitForSingleObject() -> None: # Not a handle with pytest.raises(TypeError): - await WaitForSingleObject("not a handle") # Wrong type + await WaitForSingleObject("not a handle") # type: ignore[arg-type] # Wrong type # with pytest.raises(OSError): # await WaitForSingleObject(99) # If you're unlucky, it actually IS a handle :( print("test_WaitForSingleObject not a handle OK") diff --git a/trio/_tests/tools/test_gen_exports.py b/trio/_tests/tools/test_gen_exports.py index ee51827a59..4bd8b7c089 100644 --- a/trio/_tests/tools/test_gen_exports.py +++ b/trio/_tests/tools/test_gen_exports.py @@ -1,5 +1,6 @@ import ast import sys +from pathlib import Path import pytest @@ -91,7 +92,7 @@ def test_create_pass_through_args() -> None: @skip_lints @pytest.mark.parametrize("imports", ["", IMPORT_1, IMPORT_2, IMPORT_3]) -def test_process(tmp_path, imports) -> None: +def test_process(tmp_path: Path, imports: str) -> None: try: import black # noqa: F401 # there's no dedicated CI run that has astor+isort, but lacks black. @@ -123,7 +124,7 @@ def test_process(tmp_path, imports) -> None: @skip_lints -def test_run_black(tmp_path) -> None: +def test_run_black(tmp_path: Path) -> None: """Test that processing properly fails if black does.""" try: import black # noqa: F401 @@ -140,7 +141,7 @@ def test_run_black(tmp_path) -> None: @skip_lints -def test_run_ruff(tmp_path) -> None: +def test_run_ruff(tmp_path: Path) -> None: """Test that processing properly fails if black does.""" try: import ruff # noqa: F401 @@ -166,7 +167,7 @@ def test_run_ruff(tmp_path) -> None: @skip_lints -def test_lint_failure(tmp_path) -> None: +def test_lint_failure(tmp_path: Path) -> None: """Test that processing properly fails if black or ruff does.""" try: import black # noqa: F401 From fdd3df3169c7039c0add1e65dc61d270fd54f6e0 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 18 Oct 2023 15:11:51 +0200 Subject: [PATCH 09/35] type test_file_io --- pyproject.toml | 1 - trio/_tests/test_file_io.py | 40 +++++++++++++++++++++++++------------ 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 831d577e65..0e4b7671f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,6 @@ module = [ "trio/_core/_tests/test_guest_mode", "trio/_core/_tests/test_ki", "trio/_core/_tests/test_multierror", -"trio/_tests/test_file_io", "trio/_tests/test_subprocess", "trio/_tests/test_threads", ] diff --git a/trio/_tests/test_file_io.py b/trio/_tests/test_file_io.py index 2ef02f1145..d438a9fb10 100644 --- a/trio/_tests/test_file_io.py +++ b/trio/_tests/test_file_io.py @@ -20,12 +20,12 @@ def path(tmp_path: pathlib.Path) -> str: @pytest.fixture -def wrapped(): +def wrapped() -> mock.Mock: return mock.Mock(spec_set=io.StringIO) @pytest.fixture -def async_file(wrapped): +def async_file(wrapped: mock.Mock) -> AsyncIOWrapper[mock.Mock]: return trio.wrap_file(wrapped) @@ -54,11 +54,15 @@ def write(self) -> None: # pragma: no cover trio.wrap_file(FakeFile()) -def test_wrapped_property(async_file, wrapped) -> None: +def test_wrapped_property( + async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock +) -> None: assert async_file.wrapped is wrapped -def test_dir_matches_wrapped(async_file, wrapped) -> None: +def test_dir_matches_wrapped( + async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock +) -> None: attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS) # all supported attrs in wrapped should be available in async_file @@ -122,7 +126,9 @@ def test_type_stubs_match_lists() -> None: assert found == expected -def test_sync_attrs_forwarded(async_file, wrapped) -> None: +def test_sync_attrs_forwarded( + async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock +) -> None: for attr_name in _FILE_SYNC_ATTRS: if attr_name not in dir(async_file): continue @@ -130,7 +136,9 @@ def test_sync_attrs_forwarded(async_file, wrapped) -> None: assert getattr(async_file, attr_name) is getattr(wrapped, attr_name) -def test_sync_attrs_match_wrapper(async_file, wrapped) -> None: +def test_sync_attrs_match_wrapper( + async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock +) -> None: for attr_name in _FILE_SYNC_ATTRS: if attr_name in dir(async_file): continue @@ -142,7 +150,7 @@ def test_sync_attrs_match_wrapper(async_file, wrapped) -> None: getattr(wrapped, attr_name) -def test_async_methods_generated_once(async_file) -> None: +def test_async_methods_generated_once(async_file: AsyncIOWrapper[mock.Mock]) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name not in dir(async_file): continue @@ -150,15 +158,19 @@ def test_async_methods_generated_once(async_file) -> None: assert getattr(async_file, meth_name) is getattr(async_file, meth_name) -def test_async_methods_signature(async_file) -> None: +# I gave up on typing this one +def test_async_methods_signature(async_file: AsyncIOWrapper[mock.Mock]) -> None: # use read as a representative of all async methods assert async_file.read.__name__ == "read" assert async_file.read.__qualname__ == "AsyncIOWrapper.read" + assert async_file.read.__doc__ is not None assert "io.StringIO.read" in async_file.read.__doc__ -async def test_async_methods_wrap(async_file, wrapped) -> None: +async def test_async_methods_wrap( + async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock +) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name not in dir(async_file): continue @@ -176,7 +188,9 @@ async def test_async_methods_wrap(async_file, wrapped) -> None: wrapped.reset_mock() -async def test_async_methods_match_wrapper(async_file, wrapped) -> None: +async def test_async_methods_match_wrapper( + async_file: AsyncIOWrapper[mock.Mock], wrapped: mock.Mock +) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name in dir(async_file): continue @@ -188,7 +202,7 @@ async def test_async_methods_match_wrapper(async_file, wrapped) -> None: getattr(wrapped, meth_name) -async def test_open(path) -> None: +async def test_open(path: pathlib.Path) -> None: f = await trio.open_file(path, "w") assert isinstance(f, AsyncIOWrapper) @@ -196,7 +210,7 @@ async def test_open(path) -> None: await f.aclose() -async def test_open_context_manager(path) -> None: +async def test_open_context_manager(path: pathlib.Path) -> None: async with await trio.open_file(path, "w") as f: assert isinstance(f, AsyncIOWrapper) assert not f.closed @@ -216,7 +230,7 @@ async def test_async_iter() -> None: assert result == expected -async def test_aclose_cancelled(path) -> None: +async def test_aclose_cancelled(path: pathlib.Path) -> None: with _core.CancelScope() as cscope: f = await trio.open_file(path, "w") cscope.cancel() From 55ece97dc7cd79360ac4ebab008ea17a48db9fbf Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 18 Oct 2023 15:38:49 +0200 Subject: [PATCH 10/35] type test_subprocess --- pyproject.toml | 1 - trio/_tests/test_subprocess.py | 90 +++++++++++++++++++++------------- 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0e4b7671f9..6d4dc673af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,6 @@ module = [ "trio/_core/_tests/test_guest_mode", "trio/_core/_tests/test_ki", "trio/_core/_tests/test_multierror", -"trio/_tests/test_subprocess", "trio/_tests/test_threads", ] disallow_any_decorated = false diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index b0c93cdf7a..3da8e48d63 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -8,10 +8,19 @@ from contextlib import asynccontextmanager from functools import partial from pathlib import Path as SyncPath -from typing import TYPE_CHECKING, AsyncIterator +from signal import Signals +from typing import ( + TYPE_CHECKING, + Any, + AsyncContextManager, + AsyncIterator, + Callable, + NoReturn, +) import pytest from pytest import MonkeyPatch, WarningsRecorder +from typing_extensions import TypeAlias from .. import ( ClosedResourceError, @@ -24,18 +33,21 @@ sleep, sleep_forever, ) +from .._abc import Stream from .._core._tests.tutil import skip_if_fbsd_pipes_broken, slow from ..lowlevel import open_process -from ..testing import assert_no_checkpoints, wait_all_tasks_blocked +from ..testing import MockClock, assert_no_checkpoints, wait_all_tasks_blocked + +if sys.platform == "win32": + SignalType: TypeAlias = None +else: + SignalType: TypeAlias = Signals -if TYPE_CHECKING: - ... - from signal import Signals +SIGKILL: SignalType +SIGTERM: SignalType +SIGUSR1: SignalType posix = os.name == "posix" -SIGKILL: Signals | None -SIGTERM: Signals | None -SIGUSR1: Signals | None if (not TYPE_CHECKING and posix) or sys.platform != "win32": from signal import SIGKILL, SIGTERM, SIGUSR1 else: @@ -64,15 +76,15 @@ def SLEEP(seconds: int) -> list[str]: return python(f"import time; time.sleep({seconds})") -def got_signal(proc, sig): +def got_signal(proc: Process, sig: SignalType) -> bool: if posix: return proc.returncode == -sig else: return proc.returncode != 0 -@asynccontextmanager -async def open_process_then_kill(*args, **kwargs): +@asynccontextmanager # type: ignore[misc] # Any in decorator +async def open_process_then_kill(*args: Any, **kwargs: Any) -> AsyncIterator[Process]: proc = await open_process(*args, **kwargs) try: yield proc @@ -81,16 +93,11 @@ async def open_process_then_kill(*args, **kwargs): await proc.wait() -# not entirely sure about this annotation -@asynccontextmanager -async def run_process_in_nursery( - *args, **kwargs -) -> AsyncIterator[subprocess.CompletedProcess[bytes]]: +@asynccontextmanager # type: ignore[misc] # Any in decorator +async def run_process_in_nursery(*args: Any, **kwargs: Any) -> AsyncIterator[Process]: async with _core.open_nursery() as nursery: kwargs.setdefault("check", False) - proc: subprocess.CompletedProcess[bytes] = await nursery.start( - partial(run_process, *args, **kwargs) - ) + proc: Process = await nursery.start(partial(run_process, *args, **kwargs)) yield proc nursery.cancel_scope.cancel() @@ -101,9 +108,11 @@ async def run_process_in_nursery( ids=["open_process", "run_process in nursery"], ) +BackgroundProcessType: TypeAlias = Callable[..., AsyncContextManager[Process]] + @background_process_param -async def test_basic(background_process) -> None: +async def test_basic(background_process: BackgroundProcessType) -> None: async with background_process(EXIT_TRUE) as proc: await proc.wait() assert isinstance(proc, Process) @@ -120,7 +129,9 @@ async def test_basic(background_process) -> None: @background_process_param -async def test_auto_update_returncode(background_process) -> None: +async def test_auto_update_returncode( + background_process: BackgroundProcessType, +) -> None: async with background_process(SLEEP(9999)) as p: assert p.returncode is None assert "running" in repr(p) @@ -133,7 +144,7 @@ async def test_auto_update_returncode(background_process) -> None: @background_process_param -async def test_multi_wait(background_process) -> None: +async def test_multi_wait(background_process: BackgroundProcessType) -> None: async with background_process(SLEEP(10)) as proc: # Check that wait (including multi-wait) tolerates being cancelled async with _core.open_nursery() as nursery: @@ -189,7 +200,7 @@ async def test_kill_when_context_cancelled(recwarn: WarningsRecorder) -> None: @background_process_param -async def test_pipes(background_process) -> None: +async def test_pipes(background_process: BackgroundProcessType) -> None: async with background_process( COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, stdin=subprocess.PIPE, @@ -199,10 +210,12 @@ async def test_pipes(background_process) -> None: msg = b"the quick brown fox jumps over the lazy dog" async def feed_input() -> None: + assert proc.stdin is not None await proc.stdin.send_all(msg) await proc.stdin.aclose() - async def check_output(stream, expected) -> None: + async def check_output(stream: Stream, expected: bytes) -> None: + assert type(stream) is None seen = bytearray() async for chunk in stream: seen += chunk @@ -220,7 +233,7 @@ async def check_output(stream, expected) -> None: @background_process_param -async def test_interactive(background_process) -> None: +async def test_interactive(background_process: BackgroundProcessType) -> None: # Test some back-and-forth with a subprocess. This one works like so: # in: 32\n # out: 0000...0000\n (32 zeroes) @@ -249,10 +262,10 @@ async def test_interactive(background_process) -> None: ) as proc: newline = b"\n" if posix else b"\r\n" - async def expect(idx: int, request) -> None: + async def expect(idx: int, request: int) -> None: async with _core.open_nursery() as nursery: - async def drain_one(stream, count: int, digit) -> None: + async def drain_one(stream: Stream, count: int, digit: int) -> None: while count > 0: result = await stream.receive_some(count) assert result == (f"{digit}".encode() * len(result)) @@ -263,6 +276,9 @@ async def drain_one(stream, count: int, digit) -> None: nursery.start_soon(drain_one, proc.stdout, request, idx * 2) nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1) + assert proc.stdin is not None + assert proc.stdout is not None + assert proc.stderr is not None with fail_after(5): await proc.stdin.send_all(b"12") await sleep(0.1) @@ -358,13 +374,14 @@ async def test_run_with_broken_pipe() -> None: @background_process_param -async def test_stderr_stdout(background_process) -> None: +async def test_stderr_stdout(background_process: BackgroundProcessType) -> None: async with background_process( COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, ) as proc: + assert proc.stdio is not None assert proc.stdout is not None assert proc.stderr is None await proc.stdio.send_all(b"1234") @@ -438,14 +455,17 @@ async def test_errors() -> None: @background_process_param -async def test_signals(background_process) -> None: - async def test_one_signal(send_it, signum) -> None: +async def test_signals(background_process: BackgroundProcessType) -> None: + async def test_one_signal( + send_it: Callable[[Process], None], signum: signal.Signals | None + ) -> None: with move_on_after(1.0) as scope: async with background_process(SLEEP(3600)) as proc: send_it(proc) await proc.wait() assert not scope.cancelled_caught if posix: + assert signum is not None assert proc.returncode == -signum else: assert proc.returncode != 0 @@ -465,7 +485,7 @@ async def test_one_signal(send_it, signum) -> None: @pytest.mark.skipif(not posix, reason="POSIX specific") @background_process_param -async def test_wait_reapable_fails(background_process) -> None: +async def test_wait_reapable_fails(background_process: BackgroundProcessType) -> None: old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN) try: # With SIGCHLD disabled, the wait() syscall will wait for the @@ -496,7 +516,7 @@ def test_waitid_eintr() -> None: got_alarm = False sleeper = subprocess.Popen(["sleep", "3600"]) - def on_alarm(sig, frame) -> None: + def on_alarm(sig: object, frame: object) -> None: nonlocal got_alarm got_alarm = True sleeper.kill() @@ -518,7 +538,7 @@ def on_alarm(sig, frame) -> None: async def test_custom_deliver_cancel() -> None: custom_deliver_cancel_called = False - async def custom_deliver_cancel(proc) -> None: + async def custom_deliver_cancel(proc: Process) -> None: nonlocal custom_deliver_cancel_called custom_deliver_cancel_called = True proc.terminate() @@ -542,7 +562,7 @@ async def custom_deliver_cancel(proc) -> None: async def test_warn_on_failed_cancel_terminate(monkeypatch: MonkeyPatch) -> None: original_terminate = Process.terminate - def broken_terminate(self): + def broken_terminate(self: Process) -> NoReturn: original_terminate(self) raise OSError("whoops") @@ -557,7 +577,7 @@ def broken_terminate(self): @pytest.mark.skipif(not posix, reason="posix only") async def test_warn_on_cancel_SIGKILL_escalation( - autojump_clock, monkeypatch: MonkeyPatch + autojump_clock: MockClock, monkeypatch: MonkeyPatch ) -> None: monkeypatch.setattr(Process, "terminate", lambda *args: None) From 6d19cd2a9ed04a346bc36d12301b13dacd08d1e1 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 18 Oct 2023 15:59:45 +0200 Subject: [PATCH 11/35] type test_threads --- pyproject.toml | 1 - trio/_tests/test_threads.py | 89 +++++++++++++++++++++++-------------- 2 files changed, 56 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6d4dc673af..5e3f301d63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,6 @@ module = [ "trio/_core/_tests/test_guest_mode", "trio/_core/_tests/test_ki", "trio/_core/_tests/test_multierror", -"trio/_tests/test_threads", ] disallow_any_decorated = false disallow_incomplete_defs = false diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 80b870445d..da75589840 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -8,7 +8,19 @@ import time import weakref from functools import partial -from typing import TYPE_CHECKING, Callable, Optional +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + Awaitable, + Callable, + List, + NoReturn, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import pytest import sniffio @@ -16,7 +28,7 @@ from trio._core import TrioToken, current_trio_token -from .. import CapacityLimiter, Event, _core, sleep +from .. import CancelScope, CapacityLimiter, Event, _core, sleep from .._core._tests.test_ki import ki_self from .._core._tests.tutil import buggy_pypy_asyncgens from .._threads import ( @@ -25,14 +37,23 @@ from_thread_run_sync, to_thread_run_sync, ) +from ..lowlevel import Task from ..testing import wait_all_tasks_blocked +RecordType = List[Tuple[str, Union[threading.Thread, Type[BaseException]]]] +T = TypeVar("T") + async def test_do_in_trio_thread() -> None: trio_thread = threading.current_thread() - async def check_case(do_in_trio_thread, fn, expected, trio_token=None) -> None: - record: list[tuple[str, threading.Thread | type[BaseException]]] = [] + async def check_case( + do_in_trio_thread: Callable[..., threading.Thread], + fn: Callable[..., T | Awaitable[T]], + expected: tuple[str, T], + trio_token: TrioToken | None = None, + ) -> None: + record: RecordType = [] def threadfn() -> None: try: @@ -52,21 +73,21 @@ def threadfn() -> None: token = _core.current_trio_token() - def f1(record) -> int: + def f1(record: RecordType) -> int: assert not _core.currently_ki_protected() record.append(("f", threading.current_thread())) return 2 await check_case(from_thread_run_sync, f1, ("got", 2), trio_token=token) - def f2(record): + def f2(record: RecordType) -> NoReturn: assert not _core.currently_ki_protected() record.append(("f", threading.current_thread())) raise ValueError await check_case(from_thread_run_sync, f2, ("error", ValueError), trio_token=token) - async def f3(record) -> int: + async def f3(record: RecordType) -> int: assert not _core.currently_ki_protected() await _core.checkpoint() record.append(("f", threading.current_thread())) @@ -74,7 +95,7 @@ async def f3(record) -> int: await check_case(from_thread_run, f3, ("got", 3), trio_token=token) - async def f4(record): + async def f4(record: RecordType) -> NoReturn: assert not _core.currently_ki_protected() await _core.checkpoint() record.append(("f", threading.current_thread())) @@ -151,13 +172,13 @@ async def trio_fn() -> None: ev.set() await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) - def thread_fn(token) -> None: + def thread_fn(token: TrioToken) -> None: try: from_thread_run(trio_fn, trio_token=token) except _core.Cancelled: record.append("cancelled") - async def main(): + async def main() -> threading.Thread: token = _core.current_trio_token() thread = threading.Thread(target=thread_fn, args=(token,)) thread.start() @@ -284,14 +305,14 @@ async def test_has_pthread_setname_np() -> None: async def test_run_in_worker_thread() -> None: trio_thread = threading.current_thread() - def f(x): + def f(x: T) -> tuple[T, threading.Thread]: return (x, threading.current_thread()) x, child_thread = await to_thread_run_sync(f, 1) assert x == 1 assert child_thread != trio_thread - def g(): + def g() -> NoReturn: raise ValueError(threading.current_thread()) with pytest.raises(ValueError) as excinfo: @@ -303,13 +324,13 @@ def g(): async def test_run_in_worker_thread_cancellation() -> None: register: list[str | None] = [None] - def f(q) -> None: + def f(q: stdlib_queue.Queue[str]) -> None: # Make the thread block for a controlled amount of time register[0] = "blocking" q.get() register[0] = "finished" - async def child(q, cancellable): + async def child(q: stdlib_queue.Queue[None], cancellable: bool) -> None: record.append("start") try: return await to_thread_run_sync(f, q, cancellable=cancellable) @@ -400,7 +421,9 @@ async def child() -> None: @pytest.mark.parametrize("MAX", [3, 5, 10]) @pytest.mark.parametrize("cancel", [False, True]) @pytest.mark.parametrize("use_default_limiter", [False, True]) -async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter) -> None: +async def test_run_in_worker_thread_limiter( + MAX: int, cancel: bool, use_default_limiter: bool +) -> None: # This test is a bit tricky. The goal is to make sure that if we set # limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever # running at a time, even if there are more concurrent calls to @@ -444,7 +467,7 @@ class state: token = _core.current_trio_token() - def thread_fn(cancel_scope) -> None: + def thread_fn(cancel_scope: CancelScope) -> None: print("thread_fn start") from_thread_run_sync(cancel_scope.cancel, trio_token=token) with lock: @@ -460,7 +483,7 @@ def thread_fn(cancel_scope) -> None: state.running -= 1 print("thread_fn exiting") - async def run_thread(event) -> None: + async def run_thread(event: Event) -> None: with _core.CancelScope() as cancel_scope: await to_thread_run_sync( thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel @@ -515,11 +538,11 @@ async def test_run_in_worker_thread_custom_limiter() -> None: record = [] class CustomLimiter: - async def acquire_on_behalf_of(self, borrower) -> None: + async def acquire_on_behalf_of(self, borrower: Task) -> None: record.append("acquire") self._borrower = borrower - def release_on_behalf_of(self, borrower) -> None: + def release_on_behalf_of(self, borrower: Task) -> None: record.append("release") assert borrower == self._borrower @@ -533,10 +556,10 @@ async def test_run_in_worker_thread_limiter_error() -> None: record = [] class BadCapacityLimiter: - async def acquire_on_behalf_of(self, borrower) -> None: + async def acquire_on_behalf_of(self, borrower: Task) -> None: record.append("acquire") - def release_on_behalf_of(self, borrower): + def release_on_behalf_of(self, borrower: Task) -> NoReturn: record.append("release") raise ValueError @@ -559,7 +582,7 @@ def release_on_behalf_of(self, borrower): async def test_run_in_worker_thread_fail_to_spawn(monkeypatch: MonkeyPatch) -> None: # Test the unlikely but possible case where trying to spawn a thread fails - def bad_start(self, *args): + def bad_start(self: object, *args: object) -> NoReturn: raise RuntimeError("the engines canna take it captain") monkeypatch.setattr(_core._thread_cache.ThreadCache, "start_thread_soon", bad_start) @@ -578,7 +601,7 @@ def bad_start(self, *args): async def test_trio_to_thread_run_sync_token() -> None: # Test that to_thread_run_sync automatically injects the current trio token # into a spawned thread - def thread_fn(): + def thread_fn() -> TrioToken: callee_token = from_thread_run_sync(_core.current_trio_token) return callee_token @@ -605,7 +628,7 @@ async def test_trio_to_thread_run_sync_contextvars() -> None: trio_thread = threading.current_thread() trio_test_contextvar.set("main") - def f(): + def f() -> tuple[str, threading.Thread]: value = trio_test_contextvar.get() with pytest.raises(sniffio.AsyncLibraryNotFoundError): sniffio.current_async_library() @@ -615,7 +638,7 @@ def f(): assert value == "main" assert child_thread != trio_thread - def g(): + def g() -> tuple[str, str, threading.Thread]: parent_value = trio_test_contextvar.get() trio_test_contextvar.set("worker") inner_value = trio_test_contextvar.get() @@ -641,7 +664,7 @@ def g(): async def test_trio_from_thread_run_sync() -> None: # Test that to_thread_run_sync correctly "hands off" the trio token to # trio.from_thread.run_sync() - def thread_fn_1(): + def thread_fn_1() -> float: trio_time = from_thread_run_sync(_core.current_time) return trio_time @@ -686,7 +709,7 @@ def sync_fn() -> None: # pragma: no cover async def test_trio_from_thread_token() -> None: # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() # share the same Trio token - def thread_fn(): + def thread_fn() -> TrioToken: callee_token = from_thread_run_sync(_core.current_trio_token) return callee_token @@ -698,7 +721,7 @@ def thread_fn(): async def test_trio_from_thread_token_kwarg() -> None: # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can # use an explicitly defined token - def thread_fn(token): + def thread_fn(token: TrioToken) -> TrioToken: callee_token = from_thread_run_sync(_core.current_trio_token, trio_token=token) return callee_token @@ -718,14 +741,14 @@ async def test_from_thread_no_token() -> None: async def test_trio_from_thread_run_sync_contextvars() -> None: trio_test_contextvar.set("main") - def thread_fn(): + def thread_fn() -> tuple[str, str, str, str, str]: thread_parent_value = trio_test_contextvar.get() trio_test_contextvar.set("worker") thread_current_value = trio_test_contextvar.get() with pytest.raises(sniffio.AsyncLibraryNotFoundError): sniffio.current_async_library() - def back_in_main(): + def back_in_main() -> tuple[str, str]: back_parent_value = trio_test_contextvar.get() trio_test_contextvar.set("back_in_main") back_current_value = trio_test_contextvar.get() @@ -761,14 +784,14 @@ def back_in_main(): async def test_trio_from_thread_run_contextvars() -> None: trio_test_contextvar.set("main") - def thread_fn(): + def thread_fn() -> tuple[str, str, str, str, str]: thread_parent_value = trio_test_contextvar.get() trio_test_contextvar.set("worker") thread_current_value = trio_test_contextvar.get() with pytest.raises(sniffio.AsyncLibraryNotFoundError): sniffio.current_async_library() - async def async_back_in_main(): + async def async_back_in_main() -> tuple[str, str]: back_parent_value = trio_test_contextvar.get() trio_test_contextvar.set("back_in_main") back_current_value = trio_test_contextvar.get() @@ -823,7 +846,7 @@ def test_from_thread_run_during_shutdown() -> None: save = [] record = [] - async def agen(): + async def agen() -> AsyncGenerator[None, None]: try: yield finally: From e8f01124e2d9326b33ba243b5c4bee1f5f290019 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 18 Oct 2023 16:01:54 +0200 Subject: [PATCH 12/35] finish typing test_guest_mode --- pyproject.toml | 1 - trio/_core/_tests/test_guest_mode.py | 12 ++++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5e3f301d63..9edd6d808b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,6 @@ disallow_untyped_calls = false [[tool.mypy.overrides]] module = [ # tests -"trio/_core/_tests/test_guest_mode", "trio/_core/_tests/test_ki", "trio/_core/_tests/test_multierror", ] diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index 286e8657a1..3358f63a46 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -12,7 +12,15 @@ import warnings from functools import partial from math import inf -from typing import TYPE_CHECKING, Any, Awaitable, Callable, NoReturn, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Awaitable, + Callable, + NoReturn, + TypeVar, +) import pytest from outcome import Outcome @@ -613,7 +621,7 @@ def test_guest_mode_asyncgens() -> None: record = set() - async def agen(label: str): # TODO: some asyncgenerator thing + async def agen(label: str) -> AsyncGenerator[int, None]: assert sniffio.current_async_library() == label try: yield 1 From 7ef03d9d1edf35b5242e007a042272edec86e3d4 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 18 Oct 2023 16:07:02 +0200 Subject: [PATCH 13/35] type test_ki --- pyproject.toml | 2 +- trio/_core/_tests/test_ki.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9edd6d808b..949fde14fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,7 @@ disallow_untyped_calls = false [[tool.mypy.overrides]] module = [ # tests -"trio/_core/_tests/test_ki", +#"trio/_core/_tests/test_ki", "trio/_core/_tests/test_multierror", ] disallow_any_decorated = false diff --git a/trio/_core/_tests/test_ki.py b/trio/_core/_tests/test_ki.py index 60a6b17336..bc4a0192af 100644 --- a/trio/_core/_tests/test_ki.py +++ b/trio/_core/_tests/test_ki.py @@ -4,7 +4,7 @@ import inspect import signal import threading -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator import outcome import pytest @@ -77,7 +77,7 @@ async def aunprotected() -> None: nursery.start_soon(aunprotected) @_core.enable_ki_protection - def gen_protected(): + def gen_protected() -> Iterator[None]: assert _core.currently_ki_protected() yield @@ -85,7 +85,7 @@ def gen_protected(): pass @_core.disable_ki_protection - def gen_unprotected(): + def gen_unprotected() -> Iterator[None]: assert not _core.currently_ki_protected() yield @@ -111,7 +111,7 @@ async def protected() -> None: async def unprotected() -> None: await child(False) - async def child(expected) -> None: + async def child(expected: bool) -> None: import traceback traceback.print_stack() @@ -126,10 +126,10 @@ async def child(expected) -> None: # This also used to be broken due to # https://bugs.python.org/issue29590 -async def test_generator_based_context_manager_throw(): +async def test_generator_based_context_manager_throw() -> None: @contextlib.contextmanager @_core.enable_ki_protection - def protected_manager(): + def protected_manager() -> Iterator[None]: assert _core.currently_ki_protected() try: yield @@ -193,7 +193,7 @@ async def agen_unprotected2() -> None: async def test_native_agen_protection() -> None: # Native async generators @_core.enable_ki_protection - async def agen_protected(): + async def agen_protected() -> AsyncIterator[None]: assert _core.currently_ki_protected() try: yield @@ -201,7 +201,7 @@ async def agen_protected(): assert _core.currently_ki_protected() @_core.disable_ki_protection - async def agen_unprotected(): + async def agen_unprotected() -> AsyncIterator[None]: assert not _core.currently_ki_protected() try: yield @@ -212,7 +212,7 @@ async def agen_unprotected(): await _check_agen(agen_unprotected) -async def _check_agen(agen_fn): +async def _check_agen(agen_fn: Callable[[], AsyncIterator[None]]) -> None: async for _ in agen_fn(): assert not _core.currently_ki_protected() @@ -235,7 +235,7 @@ def test_ki_disabled_out_of_context() -> None: def test_ki_disabled_in_del() -> None: - def nestedfunction(): + def nestedfunction() -> bool: return _core.currently_ki_protected() def __del__() -> None: @@ -254,14 +254,14 @@ def outerfunction() -> None: def test_ki_protection_works() -> None: - async def sleeper(name: str, record) -> None: + async def sleeper(name: str, record: set[str]) -> None: try: while True: await _core.checkpoint() except _core.Cancelled: record.add(name + " ok") - async def raiser(name: str, record): + async def raiser(name: str, record: set[str]) -> None: try: # os.kill runs signal handlers before returning, so we don't need # to worry that the handler will be delayed @@ -475,7 +475,7 @@ def test_ki_is_good_neighbor() -> None: try: orig = signal.getsignal(signal.SIGINT) - def my_handler(signum, frame) -> None: # pragma: no cover + def my_handler(signum: object, frame: object) -> None: # pragma: no cover pass async def main() -> None: From c7ff2e14ebd69b0b8d951171da8a9b5a6aea9e50 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 18 Oct 2023 16:24:15 +0200 Subject: [PATCH 14/35] type test_multierror --- pyproject.toml | 15 ++------ trio/_core/_tests/test_multierror.py | 54 +++++++++++++++------------- trio/_tests/test_subprocess.py | 1 - 3 files changed, 32 insertions(+), 38 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 949fde14fd..75849ec9bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,26 +70,15 @@ warn_return_any = true # Avoid subtle backsliding disallow_any_decorated = true disallow_any_generics = true -disallow_any_unimported = false # Enable once Outcome has stubs. disallow_incomplete_defs = true disallow_subclassing_any = true disallow_untyped_decorators = true disallow_untyped_defs = true +check_untyped_defs = true # Enable once other problems are dealt with -check_untyped_defs = true disallow_untyped_calls = false - -# files not yet fully typed -[[tool.mypy.overrides]] -module = [ -# tests -#"trio/_core/_tests/test_ki", -"trio/_core/_tests/test_multierror", -] -disallow_any_decorated = false -disallow_incomplete_defs = false -disallow_untyped_defs = false +disallow_any_unimported = false # Enable once Outcome has stubs. [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py index 60ac41ae14..f2682cdf9f 100644 --- a/trio/_core/_tests/test_multierror.py +++ b/trio/_core/_tests/test_multierror.py @@ -9,7 +9,8 @@ import warnings from pathlib import Path from traceback import extract_tb, print_exception -from typing import List +from types import TracebackType +from typing import Callable, List, NoReturn import pytest @@ -35,42 +36,43 @@ def __eq__(self, other: object) -> bool: return self.code == other.code -async def raise_nothashable(code): +async def raise_nothashable(code: int) -> NoReturn: raise NotHashableException(code) -def raiser1() -> None: +def raiser1() -> NoReturn: raiser1_2() -def raiser1_2() -> None: +def raiser1_2() -> NoReturn: raiser1_3() -def raiser1_3(): +def raiser1_3() -> NoReturn: raise ValueError("raiser1_string") -def raiser2() -> None: +def raiser2() -> NoReturn: raiser2_2() -def raiser2_2(): +def raiser2_2() -> NoReturn: raise KeyError("raiser2_string") -def raiser3(): +def raiser3() -> NoReturn: raise NameError -def get_exc(raiser): +def get_exc(raiser: Callable[[], NoReturn]) -> BaseException: try: raiser() except Exception as exc: return exc + raise AssertionError("raiser should always raise") -def get_tb(raiser): +def get_tb(raiser: Callable[[], NoReturn]) -> TracebackType | None: return get_exc(raiser).__traceback__ @@ -141,7 +143,7 @@ async def test_MultiErrorNotHashable() -> None: def test_MultiError_filter_NotHashable() -> None: excs = MultiError([NotHashableException(42), ValueError()]) - def handle_ValueError(exc): + def handle_ValueError(exc: BaseException) -> BaseException | None: if isinstance(exc, ValueError): return None else: @@ -153,7 +155,7 @@ def handle_ValueError(exc): assert isinstance(filtered_excs, NotHashableException) -def make_tree(): +def make_tree() -> MultiError: # Returns an object like: # MultiError([ # MultiError([ @@ -174,7 +176,9 @@ def make_tree(): return MultiError([m12, exc3]) -def assert_tree_eq(m1, m2) -> None: +def assert_tree_eq( + m1: BaseException | MultiError | None, m2: BaseException | MultiError | None +) -> None: if m1 is None or m2 is None: assert m1 is m2 return @@ -183,13 +187,14 @@ def assert_tree_eq(m1, m2) -> None: assert_tree_eq(m1.__cause__, m2.__cause__) assert_tree_eq(m1.__context__, m2.__context__) if isinstance(m1, MultiError): + assert isinstance(m2, MultiError) assert len(m1.exceptions) == len(m2.exceptions) for e1, e2 in zip(m1.exceptions, m2.exceptions): assert_tree_eq(e1, e2) -def test_MultiError_filter(): - def null_handler(exc): +def test_MultiError_filter() -> None: + def null_handler(exc: BaseException) -> BaseException: return exc m = make_tree() @@ -209,7 +214,7 @@ def null_handler(exc): assert MultiError.filter(null_handler, m) is m assert_tree_eq(m, make_tree()) - def simple_filter(exc): + def simple_filter(exc: BaseException) -> BaseException | None: if isinstance(exc, ValueError): return None if isinstance(exc, KeyError): @@ -233,6 +238,7 @@ def simple_filter(exc): # traceback on its parent MultiError orig = make_tree() # make sure we have the right path + assert isinstance(orig.exceptions[0], MultiError) assert isinstance(orig.exceptions[0].exceptions[1], KeyError) # get original traceback summary orig_extracted = ( @@ -241,7 +247,7 @@ def simple_filter(exc): + extract_tb(orig.exceptions[0].exceptions[1].__traceback__) ) - def p(exc) -> None: + def p(exc: BaseException) -> None: print_exception(type(exc), exc, exc.__traceback__) p(orig) @@ -254,7 +260,7 @@ def p(exc) -> None: assert orig_extracted == new_extracted # check preserving partial tree - def filter_NameError(exc): + def filter_NameError(exc: BaseException) -> BaseException | None: if isinstance(exc, NameError): return None return exc @@ -266,17 +272,17 @@ def filter_NameError(exc): assert new_m is m.exceptions[0] # check fully handling everything - def filter_all(exc): + def filter_all(exc: BaseException) -> None: return None with pytest.warns(TrioDeprecationWarning): assert MultiError.filter(filter_all, make_tree()) is None -def test_MultiError_catch(): +def test_MultiError_catch() -> None: # No exception to catch - def noop(_) -> None: + def noop(_: object) -> None: pass # pragma: no cover with pytest.warns(TrioDeprecationWarning), MultiError.catch(noop): @@ -363,16 +369,16 @@ def catch_RuntimeError(exc): @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) -def test_MultiError_catch_doesnt_create_cyclic_garbage(): +def test_MultiError_catch_doesnt_create_cyclic_garbage() -> None: # https://github.com/python-trio/trio/pull/2063 gc.collect() old_flags = gc.get_debug() - def make_multi(): + def make_multi() -> NoReturn: # make_tree creates cycles itself, so a simple raise MultiError([get_exc(raiser1), get_exc(raiser2)]) - def simple_filter(exc): + def simple_filter(exc: BaseException) -> Exception | RuntimeError: if isinstance(exc, ValueError): return Exception() if isinstance(exc, KeyError): diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index 3da8e48d63..e4713b3950 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -215,7 +215,6 @@ async def feed_input() -> None: await proc.stdin.aclose() async def check_output(stream: Stream, expected: bytes) -> None: - assert type(stream) is None seen = bytearray() async for chunk in stream: seen += chunk From e873df1eb3dc8a670852d9409aea16abc5606c11 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 18 Oct 2023 17:35:30 +0200 Subject: [PATCH 15/35] fix some type errors I missed outside of _tests/ - testing/_fake_net.py and missing type parameters for SSLStream --- trio/_highlevel_ssl_helpers.py | 17 ++- trio/_ssl.py | 6 +- trio/_tests/test_highlevel_ssl_helpers.py | 4 +- trio/_tests/test_ssl.py | 2 +- trio/testing/_fake_net.py | 133 +++++++++++++++------- 5 files changed, 112 insertions(+), 50 deletions(-) diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py index 1647f373c2..c03919d6c0 100644 --- a/trio/_highlevel_ssl_helpers.py +++ b/trio/_highlevel_ssl_helpers.py @@ -2,11 +2,16 @@ import ssl from collections.abc import Awaitable, Callable -from typing import NoReturn +from typing import NoReturn, TypeVar import trio from ._highlevel_open_tcp_stream import DEFAULT_DELAY +from ._highlevel_socket import SocketStream +from .abc import Stream + +T = TypeVar("T") +T_Stream = TypeVar("T_Stream", bound=Stream) # It might have been nice to take a ssl_protocols= argument here to set up @@ -25,7 +30,7 @@ async def open_ssl_over_tcp_stream( https_compatible: bool = False, ssl_context: ssl.SSLContext | None = None, happy_eyeballs_delay: float | None = DEFAULT_DELAY, -) -> trio.SSLStream: +) -> trio.SSLStream[SocketStream]: """Make a TLS-encrypted Connection to the given host and port over TCP. This is a convenience wrapper that calls :func:`open_tcp_stream` and @@ -73,7 +78,7 @@ async def open_ssl_over_tcp_listeners( host: str | bytes | None = None, https_compatible: bool = False, backlog: int | float | None = None, -) -> list[trio.SSLListener]: +) -> list[trio.SSLListener[SocketStream]]: """Start listening for SSL/TLS-encrypted TCP connections to the given port. Args: @@ -95,7 +100,7 @@ async def open_ssl_over_tcp_listeners( async def serve_ssl_over_tcp( - handler: Callable[[trio.SSLStream], Awaitable[object]], + handler: Callable[[trio.SSLStream[T_Stream]], Awaitable[object]], port: int, ssl_context: ssl.SSLContext, *, @@ -103,7 +108,9 @@ async def serve_ssl_over_tcp( https_compatible: bool = False, backlog: int | float | None = None, handler_nursery: trio.Nursery | None = None, - task_status: trio.TaskStatus[list[trio.SSLListener]] = trio.TASK_STATUS_IGNORED, + task_status: trio.TaskStatus[ + list[trio.SSLListener[T_Stream]] + ] = trio.TASK_STATUS_IGNORED, ) -> NoReturn: """Listen for incoming TCP connections, and for each one start a task running ``handler(stream)``. diff --git a/trio/_ssl.py b/trio/_ssl.py index 073d7e3812..77d3b80140 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -893,7 +893,7 @@ async def wait_send_all_might_not_block(self) -> None: @final -class SSLListener(Listener[SSLStream]): +class SSLListener(Listener[SSLStream[T_Stream]]): """A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers. :class:`SSLListener` wraps around another Listener, and converts @@ -917,7 +917,7 @@ class SSLListener(Listener[SSLStream]): def __init__( self, - transport_listener: Listener[Stream], + transport_listener: Listener[T_Stream], ssl_context: _stdlib_ssl.SSLContext, *, https_compatible: bool = False, @@ -926,7 +926,7 @@ def __init__( self._ssl_context = ssl_context self._https_compatible = https_compatible - async def accept(self) -> SSLStream: + async def accept(self) -> SSLStream[T_Stream]: """Accept the next connection and wrap it in an :class:`SSLStream`. See :meth:`trio.abc.Listener.accept` for details. diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py index 5a06279204..c1b0febbd5 100644 --- a/trio/_tests/test_highlevel_ssl_helpers.py +++ b/trio/_tests/test_highlevel_ssl_helpers.py @@ -74,13 +74,13 @@ async def test_open_ssl_over_tcp_stream_and_everything_else( async with trio.open_nursery() as nursery: # TODO: the types are *very* funky here, this seems like an error in some signature # unless this is doing stuff we don't want/expect end users to do - res: list[SSLListener] = await nursery.start( + res: list[SSLListener[SocketListener]] = await nursery.start( partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1") ) (listener,) = res async with listener: # listener.transport_listener is of type Listener[Stream] - tp_listener: SocketListener = listener.transport_listener # type: ignore[assignment] + tp_listener: SocketListener = listener.transport_listener sockaddr = tp_listener.socket.getsockname() hostname_resolver = FakeHostnameResolver(sockaddr) diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index 5dece46e54..58e069f239 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -1295,7 +1295,7 @@ async def test_getpeercert(client_ctx: SSLContext) -> None: async def test_SSLListener(client_ctx: SSLContext) -> None: async def setup( **kwargs: Any, - ) -> tuple[tsocket.SocketType, SSLListener, SSLStream[SocketStream]]: + ) -> tuple[tsocket.SocketType, SSLListener[SocketStream], SSLStream[SocketStream]]: listen_sock = tsocket.socket() await listen_sock.bind(("127.0.0.1", 0)) listen_sock.listen(1) diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index 74ce32d37f..2b1d2c8b34 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -12,9 +12,19 @@ import errno import ipaddress import os -from typing import TYPE_CHECKING, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + NoReturn, + Optional, + TypeVar, + Union, + overload, +) import attr +from typing_extensions import Buffer, Self import trio from trio._util import NoPublicConstructor, final @@ -54,11 +64,11 @@ def _localhost_ip_for(family: int) -> IPAddress: assert False -def _fake_err(code): +def _fake_err(code: int) -> NoReturn: raise OSError(code, os.strerror(code)) -def _scatter(data, buffers): +def _scatter(data: bytes, buffers: Iterable[bytes]) -> int: written = 0 for buf in buffers: next_piece = data[written : written + len(buf)] @@ -70,19 +80,27 @@ def _scatter(data, buffers): return written +T_UDPEndpoint = TypeVar("T_UDPEndpoint", bound="UDPEndpoint") + + @attr.frozen class UDPEndpoint: ip: IPAddress port: int - def as_python_sockaddr(self): - sockaddr = (self.ip.compressed, self.port) + def as_python_sockaddr(self) -> tuple[str, int] | tuple[str, int, int, int]: + sockaddr: tuple[str, int] | tuple[str, int, int, int] = ( + self.ip.compressed, + self.port, + ) if isinstance(self.ip, ipaddress.IPv6Address): - sockaddr += (0, 0) + sockaddr += (0, 0) # type: ignore[assignment] return sockaddr @classmethod - def from_python_sockaddr(cls, sockaddr): + def from_python_sockaddr( + cls: type[T_UDPEndpoint], sockaddr: tuple[str, int] | tuple[str, int, int, int] + ) -> T_UDPEndpoint: ip, port = sockaddr[:2] return cls(ip=ipaddress.ip_address(ip), port=port) @@ -98,7 +116,7 @@ class UDPPacket: destination: UDPEndpoint payload: bytes = attr.ib(repr=lambda p: p.hex()) - def reply(self, payload): + def reply(self, payload: bytes) -> UDPPacket: return UDPPacket( source=self.destination, destination=self.source, payload=payload ) @@ -162,13 +180,13 @@ def enable(self) -> None: trio.socket.set_custom_socket_factory(FakeSocketFactory(self)) trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self)) - def send_packet(self, packet) -> None: + def send_packet(self, packet: UDPPacket) -> None: if self.route_packet is None: self.deliver_packet(packet) else: self.route_packet(packet) - def deliver_packet(self, packet) -> None: + def deliver_packet(self, packet: UDPPacket) -> None: binding = UDPBinding(local=packet.destination) if binding in self._bound: self._bound[binding]._deliver_packet(packet) @@ -219,11 +237,11 @@ def family(self) -> AddressFamily: def proto(self) -> int: return self._proto - def _check_closed(self): + def _check_closed(self) -> None: if self._closed: _fake_err(errno.EBADF) - def close(self): + def close(self) -> None: # breakpoint() if self._closed: return @@ -232,8 +250,10 @@ def close(self): del self._fake_net._bound[self._binding] self._packet_receiver.close() - async def _resolve_address_nocp(self, address, *, local): - return await trio._socket._resolve_address_nocp( + async def _resolve_address_nocp( + self, address: object, *, local: bool + ) -> tuple[str, int]: + return await trio._socket._resolve_address_nocp( # type: ignore[no-any-return] self.type, self.family, self.proto, @@ -253,7 +273,7 @@ def _deliver_packet(self, packet: UDPPacket) -> None: # Actual IO operation implementations ################################################################ - async def bind(self, addr): + async def bind(self, addr: object) -> None: self._check_closed() if self._binding is not None: _fake_err(errno.EINVAL) @@ -272,10 +292,10 @@ async def bind(self, addr): self._fake_net._bind(binding, self) self._binding = binding - async def connect(self, peer): + async def connect(self, peer: object) -> NoReturn: raise NotImplementedError("FakeNet does not (yet) support connected sockets") - async def sendmsg(self, *args): + async def sendmsg(self, *args: Any) -> int: self._check_closed() ancdata = [] flags = 0 @@ -310,6 +330,7 @@ async def sendmsg(self, *args): payload = b"".join(buffers) + assert self._binding is not None packet = UDPPacket( source=self._binding.local, destination=destination, @@ -320,7 +341,12 @@ async def sendmsg(self, *args): return len(payload) - async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): + async def recvmsg_into( + self, + buffers: Iterable[Buffer], + ancbufsize: int = 0, + flags: int = 0, + ) -> tuple[int, list[tuple[int, int, bytes]], int, Any]: if ancbufsize != 0: raise NotImplementedError("FakeNet doesn't support ancillary data") if flags != 0: @@ -328,7 +354,7 @@ async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): self._check_closed() - ancdata = [] + ancdata: list[tuple[int, int, bytes]] = [] msg_flags = 0 packet = await self._packet_receiver.receive() @@ -342,7 +368,7 @@ async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): # Simple state query stuff ################################################################ - def getsockname(self): + def getsockname(self) -> tuple[str, int] | tuple[str, int, int, int]: self._check_closed() if self._binding is not None: return self._binding.local.as_python_sockaddr() @@ -352,31 +378,56 @@ def getsockname(self): assert self.family == trio.socket.AF_INET6 return ("::", 0) - def getpeername(self): + def getpeername(self) -> None: self._check_closed() if self._binding is not None: if self._binding.remote is not None: return self._binding.remote.as_python_sockaddr() _fake_err(errno.ENOTCONN) - def getsockopt(self, level, item): + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: self._check_closed() - raise OSError(f"FakeNet doesn't implement getsockopt({level}, {item})") + raise OSError(f"FakeNet doesn't implement getsockopt({level}, {optname})") + + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... - def setsockopt(self, level, item, value): + @overload + def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: + ... + + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: self._check_closed() - if (level, item) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY): + if (level, optname) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY): if not value: raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True") - raise OSError(f"FakeNet doesn't implement setsockopt({level}, {item}, ...)") + raise OSError(f"FakeNet doesn't implement setsockopt({level}, {optname}, ...)") ################################################################ # Various boilerplate and trivial stubs ################################################################ - def __enter__(self): + def __enter__(self) -> Self: return self def __exit__( @@ -387,10 +438,10 @@ def __exit__( ) -> None: self.close() - async def send(self, data, flags=0): + async def send(self, data: Buffer, flags: int = 0) -> int: return await self.sendto(data, flags, None) - async def sendto(self, *args): + async def sendto(self, *args: Any) -> int: if len(args) == 2: data, address = args flags = 0 @@ -400,19 +451,21 @@ async def sendto(self, *args): raise TypeError("wrong number of arguments") return await self.sendmsg([data], [], flags, address) - async def recv(self, bufsize, flags=0): + async def recv(self, bufsize: int, flags: int = 0) -> bytes: data, address = await self.recvfrom(bufsize, flags) return data - async def recv_into(self, buf, nbytes=0, flags=0): + async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int: got_bytes, address = await self.recvfrom_into(buf, nbytes, flags) return got_bytes - async def recvfrom(self, bufsize, flags=0): + async def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, Any]: data, ancdata, msg_flags, address = await self.recvmsg(bufsize, flags) return data, address - async def recvfrom_into(self, buf, nbytes=0, flags=0): + async def recvfrom_into( + self, buf: Buffer, nbytes: int = 0, flags: int = 0 + ) -> tuple[int, Any]: if nbytes != 0 and nbytes != len(buf): raise NotImplementedError("partial recvfrom_into") got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( @@ -420,25 +473,27 @@ async def recvfrom_into(self, buf, nbytes=0, flags=0): ) return got_nbytes, address - async def recvmsg(self, bufsize, ancbufsize=0, flags=0): + async def recvmsg( + self, bufsize: int, ancbufsize: int = 0, flags: int = 0 + ) -> tuple[bytes, list[tuple[int, int, bytes]], int, Any]: buf = bytearray(bufsize) got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( [buf], ancbufsize, flags ) return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address) - def fileno(self): + def fileno(self) -> int: raise NotImplementedError("can't get fileno() for FakeNet sockets") - def detach(self): + def detach(self) -> int: raise NotImplementedError("can't detach() a FakeNet socket") - def get_inheritable(self): + def get_inheritable(self) -> bool: return False - def set_inheritable(self, inheritable): + def set_inheritable(self, inheritable: bool) -> None: if inheritable: raise NotImplementedError("FakeNet can't make inheritable sockets") - def share(self, process_id): + def share(self, process_id: int) -> bytes: raise NotImplementedError("FakeNet can't share sockets") From ef91fb7d700fec6deb595b1852ab6b00a28419ab Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Wed, 18 Oct 2023 18:19:48 -0500 Subject: [PATCH 16/35] Fix runtime instantation of objects that don't define `__class_getitem__` --- trio/_core/_tests/test_guest_mode.py | 3 ++- trio/_core/_tests/test_thread_cache.py | 8 ++++---- trio/_tests/test_threads.py | 13 +++++++------ 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index 3358f63a46..d53048b067 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -20,6 +20,7 @@ Callable, NoReturn, TypeVar, + cast, ) import pytest @@ -458,7 +459,7 @@ async def trio_main() -> str: print("trio_main!") to_trio, from_aio = trio.open_memory_channel[int](float("inf")) - from_trio = asyncio.Queue[int]() + from_trio = cast("asyncio.Queue[int]", asyncio.Queue)() aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio)) diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py index 2b8d913948..277b6d6bb5 100644 --- a/trio/_core/_tests/test_thread_cache.py +++ b/trio/_core/_tests/test_thread_cache.py @@ -4,7 +4,7 @@ import time from contextlib import contextmanager from queue import Queue -from typing import Iterator, NoReturn +from typing import Iterator, NoReturn, cast import pytest from outcome import Outcome @@ -16,7 +16,7 @@ def test_thread_cache_basics() -> None: - q = Queue[Outcome]() + q = cast("Queue[Outcome]", Queue)() def fn() -> NoReturn: raise RuntimeError("hi") @@ -41,7 +41,7 @@ def __call__(self) -> int: def __del__(self) -> None: res[0] = True - q = Queue[Outcome]() + q = cast("Queue[Outcome]", Queue)() def deliver(outcome: Outcome) -> None: q.put(outcome) @@ -64,7 +64,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None: # Make sure there are a few threads running, so if we weren't LIFO then we # could grab the wrong one. - q = Queue[Outcome]() + q = cast("Queue[Outcome]", Queue)() COUNT = 5 for _ in range(COUNT): start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result)) diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 77a3bf874c..866448dac8 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -20,6 +20,7 @@ Type, TypeVar, Union, + cast, ) import pytest @@ -346,7 +347,7 @@ async def child(q: stdlib_queue.Queue[None], cancellable: bool) -> None: record.append("exit") record: list[str] = [] - q = stdlib_queue.Queue[None]() + q = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue)() async with _core.open_nursery() as nursery: nursery.start_soon(child, q, True) # Give it a chance to get started. (This is important because @@ -394,8 +395,8 @@ def test_run_in_worker_thread_abandoned( ) -> None: monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) - q1 = stdlib_queue.Queue[None]() - q2 = stdlib_queue.Queue[threading.Thread]() + q1 = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue)() + q2 = cast("stdlib_queue.Queue[threading.Thread]", stdlib_queue.Queue)() def thread_fn() -> None: q1.get() @@ -918,7 +919,7 @@ def get_tid_then_reenter() -> int: async def test_from_thread_host_cancelled() -> None: - queue = stdlib_queue.Queue[bool]() + queue = cast("stdlib_queue.Queue[bool]", stdlib_queue.Queue)() def sync_check() -> None: from_thread_run_sync(cancel_scope.cancel) @@ -977,7 +978,7 @@ async def async_time_bomb() -> None: async def test_from_thread_check_cancelled() -> None: - q = stdlib_queue.Queue[str]() + q = cast("stdlib_queue.Queue[str]", stdlib_queue.Queue)() async def child(cancellable: bool, scope: CancelScope) -> None: with scope: @@ -1057,7 +1058,7 @@ def f() -> None: # type: ignore[no-redef] # noqa: F811 async def test_from_thread_check_cancelled_raises_in_foreign_threads() -> None: with pytest.raises(RuntimeError): from_thread_check_cancelled() - q = stdlib_queue.Queue[Outcome]() + q = cast("stdlib_queue.Queue[Outcome]", stdlib_queue.Queue)() _core.start_thread_soon(from_thread_check_cancelled, lambda _: q.put(_)) with pytest.raises(RuntimeError): q.get(timeout=1).unwrap() From 97eba76551cc6fe3465f7b45dfb9d15b3a25337f Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Wed, 18 Oct 2023 18:25:36 -0500 Subject: [PATCH 17/35] Missed one of the `Queue`s --- trio/_core/_tests/test_thread_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py index 277b6d6bb5..8f9be0d23f 100644 --- a/trio/_core/_tests/test_thread_cache.py +++ b/trio/_core/_tests/test_thread_cache.py @@ -96,7 +96,7 @@ def test_idle_threads_exit(monkeypatch: MonkeyPatch) -> None: # CPU.) monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001) - q = Queue[threading.Thread]() + q = cast("Queue[threading.Thread]", Queue)() start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread())) seen_thread = q.get() # Since the idle timeout is 0, after sleeping for 1 second, the thread From f3e5a36b4705430870e3532655181a5aa7e237d7 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Wed, 18 Oct 2023 18:35:44 -0500 Subject: [PATCH 18/35] Fix not callable issue I just created with the last fix --- trio/_core/_tests/test_guest_mode.py | 2 +- trio/_core/_tests/test_thread_cache.py | 8 ++++---- trio/_tests/test_threads.py | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index d53048b067..bd0ef9ce69 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -459,7 +459,7 @@ async def trio_main() -> str: print("trio_main!") to_trio, from_aio = trio.open_memory_channel[int](float("inf")) - from_trio = cast("asyncio.Queue[int]", asyncio.Queue)() + from_trio = cast("asyncio.Queue[int]", asyncio.Queue()) aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio)) diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py index 8f9be0d23f..e5b4902d8d 100644 --- a/trio/_core/_tests/test_thread_cache.py +++ b/trio/_core/_tests/test_thread_cache.py @@ -16,7 +16,7 @@ def test_thread_cache_basics() -> None: - q = cast("Queue[Outcome]", Queue)() + q = cast("Queue[Outcome]", Queue()) def fn() -> NoReturn: raise RuntimeError("hi") @@ -41,7 +41,7 @@ def __call__(self) -> int: def __del__(self) -> None: res[0] = True - q = cast("Queue[Outcome]", Queue)() + q = cast("Queue[Outcome]", Queue()) def deliver(outcome: Outcome) -> None: q.put(outcome) @@ -64,7 +64,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None: # Make sure there are a few threads running, so if we weren't LIFO then we # could grab the wrong one. - q = cast("Queue[Outcome]", Queue)() + q = cast("Queue[Outcome]", Queue()) COUNT = 5 for _ in range(COUNT): start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result)) @@ -96,7 +96,7 @@ def test_idle_threads_exit(monkeypatch: MonkeyPatch) -> None: # CPU.) monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001) - q = cast("Queue[threading.Thread]", Queue)() + q = cast("Queue[threading.Thread]", Queue()) start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread())) seen_thread = q.get() # Since the idle timeout is 0, after sleeping for 1 second, the thread diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 866448dac8..4fa1e4bedf 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -347,7 +347,7 @@ async def child(q: stdlib_queue.Queue[None], cancellable: bool) -> None: record.append("exit") record: list[str] = [] - q = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue)() + q = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue()) async with _core.open_nursery() as nursery: nursery.start_soon(child, q, True) # Give it a chance to get started. (This is important because @@ -395,8 +395,8 @@ def test_run_in_worker_thread_abandoned( ) -> None: monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) - q1 = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue)() - q2 = cast("stdlib_queue.Queue[threading.Thread]", stdlib_queue.Queue)() + q1 = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue()) + q2 = cast("stdlib_queue.Queue[threading.Thread]", stdlib_queue.Queue()) def thread_fn() -> None: q1.get() @@ -919,7 +919,7 @@ def get_tid_then_reenter() -> int: async def test_from_thread_host_cancelled() -> None: - queue = cast("stdlib_queue.Queue[bool]", stdlib_queue.Queue)() + queue = cast("stdlib_queue.Queue[bool]", stdlib_queue.Queue()) def sync_check() -> None: from_thread_run_sync(cancel_scope.cancel) @@ -978,7 +978,7 @@ async def async_time_bomb() -> None: async def test_from_thread_check_cancelled() -> None: - q = cast("stdlib_queue.Queue[str]", stdlib_queue.Queue)() + q = cast("stdlib_queue.Queue[str]", stdlib_queue.Queue()) async def child(cancellable: bool, scope: CancelScope) -> None: with scope: @@ -1058,7 +1058,7 @@ def f() -> None: # type: ignore[no-redef] # noqa: F811 async def test_from_thread_check_cancelled_raises_in_foreign_threads() -> None: with pytest.raises(RuntimeError): from_thread_check_cancelled() - q = cast("stdlib_queue.Queue[Outcome]", stdlib_queue.Queue)() + q = cast("stdlib_queue.Queue[Outcome]", stdlib_queue.Queue()) _core.start_thread_soon(from_thread_check_cancelled, lambda _: q.put(_)) with pytest.raises(RuntimeError): q.get(timeout=1).unwrap() From 3c508020004ad03b076a256b8c8f2f64fc3966b4 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 19 Oct 2023 16:16:52 +0200 Subject: [PATCH 19/35] fixes after suggestions from teamspen & coolcat <3 --- trio/_core/_local.py | 2 +- trio/_core/_tests/test_guest_mode.py | 7 +++++-- trio/_ssl.py | 2 +- trio/_tests/test_file_io.py | 16 +++++++++------- trio/_tests/test_highlevel_generic.py | 3 ++- trio/_tests/test_ssl.py | 6 ++++-- trio/_tests/test_subprocess.py | 4 +++- trio/_tests/test_sync.py | 6 ++++-- trio/_tests/test_threads.py | 11 ++++------- trio/testing/_fake_net.py | 3 +-- 10 files changed, 34 insertions(+), 26 deletions(-) diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 27252bc78d..f1cbb3e61f 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -17,7 +17,7 @@ class _NoValue: @final -@attr.s(eq=False, hash=False, slots=False) +@attr.s(eq=False, hash=False, slots=True) class RunVarToken(Generic[T], metaclass=NoPublicConstructor): _var: RunVar[T] = attr.ib() previous_value: T | type[_NoValue] = attr.ib(default=_NoValue) diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index bd0ef9ce69..be9d4cb9c2 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -26,15 +26,18 @@ import pytest from outcome import Outcome from pytest import MonkeyPatch, WarningsRecorder -from typing_extensions import TypeAlias import trio import trio.testing from trio._channel import MemorySendChannel +from trio.abc import Instrument from ..._util import signal_raise from .tutil import buggy_pypy_asyncgens, gc_collect_harder, restore_unraisablehook +if TYPE_CHECKING: + from typing_extensions import TypeAlias + T = TypeVar("T") InHost: TypeAlias = Callable[[object], None] @@ -346,7 +349,7 @@ async def get_woken_by_host_deadline(watb_cscope: trio.CancelScope) -> None: # 'sit_in_wait_all_tasks_blocked', we want the test to # actually end. So in after_io_wait we schedule a second host # call to tear things down. - class InstrumentHelper(trio._abc.Instrument): + class InstrumentHelper(Instrument): def __init__(self) -> None: self.primed = False diff --git a/trio/_ssl.py b/trio/_ssl.py index 77d3b80140..57f41c11d2 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -239,7 +239,7 @@ def done(self) -> bool: _State = _Enum("_State", ["OK", "BROKEN", "CLOSED"]) -# TODO: variance +# invariant T_Stream = TypeVar("T_Stream", bound=Stream) diff --git a/trio/_tests/test_file_io.py b/trio/_tests/test_file_io.py index d438a9fb10..5d617a0661 100644 --- a/trio/_tests/test_file_io.py +++ b/trio/_tests/test_file_io.py @@ -244,13 +244,15 @@ async def test_aclose_cancelled(path: pathlib.Path) -> None: assert f.closed -async def test_detach_rewraps_asynciobase() -> None: - raw = io.BytesIO() - buffered = io.BufferedReader(raw) # type: ignore[arg-type] # ???????????? +async def test_detach_rewraps_asynciobase(tmp_path: pathlib.Path) -> None: + tmp_file = tmp_path / "filename" + tmp_file.touch() + with open(tmp_file, mode="rb", buffering=0) as raw: + buffered = io.BufferedReader(raw) - async_file = trio.wrap_file(buffered) + async_file = trio.wrap_file(buffered) - detached = await async_file.detach() + detached = await async_file.detach() - assert isinstance(detached, AsyncIOWrapper) - assert detached.wrapped is raw + assert isinstance(detached, AsyncIOWrapper) + assert detached.wrapped is raw diff --git a/trio/_tests/test_highlevel_generic.py b/trio/_tests/test_highlevel_generic.py index 4b2008c08c..3e9fc212a8 100644 --- a/trio/_tests/test_highlevel_generic.py +++ b/trio/_tests/test_highlevel_generic.py @@ -27,8 +27,9 @@ async def aclose(self) -> None: class RecordReceiveStream(ReceiveStream): record: list[str | tuple[str, int | None]] = attr.ib(factory=list) - async def receive_some(self, max_bytes: int | None = None) -> None: # type: ignore[override] + async def receive_some(self, max_bytes: int | None = None) -> bytes: self.record.append(("receive_some", max_bytes)) + return b"" async def aclose(self) -> None: self.record.append("aclose") diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index 58e069f239..4a4ba95ecd 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -8,10 +8,9 @@ from contextlib import asynccontextmanager, contextmanager from functools import partial from ssl import SSLContext -from typing import Any, AsyncIterator, Iterator, NoReturn +from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, NoReturn import pytest -from typing_extensions import TypeAlias from trio import StapledStream from trio._core import MockClock @@ -46,6 +45,9 @@ memory_stream_pair, ) +if TYPE_CHECKING: + from typing_extensions import TypeAlias + # We have two different kinds of echo server fixtures we use for testing. The # first is a real server written using the stdlib ssl module and blocking # sockets. It runs in a thread and we talk to it over a real socketpair(), to diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index e4713b3950..8c32f1c49b 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -20,7 +20,6 @@ import pytest from pytest import MonkeyPatch, WarningsRecorder -from typing_extensions import TypeAlias from .. import ( ClosedResourceError, @@ -38,6 +37,9 @@ from ..lowlevel import open_process from ..testing import MockClock, assert_no_checkpoints, wait_all_tasks_blocked +if TYPE_CHECKING: + from typing_extensions import TypeAlias + if sys.platform == "win32": SignalType: TypeAlias = None else: diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py index 448ead15da..3f9ca88650 100644 --- a/trio/_tests/test_sync.py +++ b/trio/_tests/test_sync.py @@ -1,16 +1,18 @@ from __future__ import annotations import weakref -from typing import Callable, Union +from typing import TYPE_CHECKING, Callable, Union import pytest -from typing_extensions import TypeAlias from .. import _core from .._sync import * from .._timeouts import sleep_forever from ..testing import assert_checkpoints, wait_all_tasks_blocked +if TYPE_CHECKING: + from typing_extensions import TypeAlias + async def test_Event() -> None: e = Event() diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 4fa1e4bedf..77650bf569 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -9,7 +9,6 @@ import weakref from functools import partial from typing import ( - TYPE_CHECKING, AsyncGenerator, Awaitable, Callable, @@ -462,12 +461,10 @@ async def test_run_in_worker_thread_limiter( # Mutating them in-place is OK though (as long as you use proper # locking etc.). class state: - if TYPE_CHECKING: - ran: int - high_water: int - running: int - parked: int - pass + ran: int + high_water: int + running: int + parked: int state.ran = 0 state.high_water = 0 diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index 2b1d2c8b34..fbe2670f46 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -24,7 +24,6 @@ ) import attr -from typing_extensions import Buffer, Self import trio from trio._util import NoPublicConstructor, final @@ -33,7 +32,7 @@ from socket import AddressFamily, SocketKind from types import TracebackType - from typing_extensions import TypeAlias + from typing_extensions import Buffer, Self, TypeAlias IPAddress: TypeAlias = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] From 7f0121ba5dd759b19f6332fdb4ea7b20465622b7 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 20 Oct 2023 16:35:21 +0200 Subject: [PATCH 20/35] fix a few more type errors --- trio/_highlevel_serve_listeners.py | 6 +++--- trio/_highlevel_ssl_helpers.py | 6 ++---- trio/_subprocess_platform/waitid.py | 5 +++++ trio/testing/_fake_net.py | 28 +++++++++++++++++++++------- 4 files changed, 31 insertions(+), 14 deletions(-) diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py index d5c7a3bdad..d949cb8f87 100644 --- a/trio/_highlevel_serve_listeners.py +++ b/trio/_highlevel_serve_listeners.py @@ -66,9 +66,7 @@ async def _serve_one_listener( # https://github.com/python/typing/issues/548 -# It does never return (since _serve_one_listener never completes), but type checkers can't -# understand nurseries. -async def serve_listeners( # type: ignore[misc] +async def serve_listeners( handler: Handler[StreamT], listeners: list[ListenerT], *, @@ -143,3 +141,5 @@ async def serve_listeners( # type: ignore[misc] # but we wait until the end to call started() just in case we get an # error or whatever. task_status.started(listeners) + + raise AssertionError("_serve_one_listener should never complete") diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py index c03919d6c0..0006d7a41f 100644 --- a/trio/_highlevel_ssl_helpers.py +++ b/trio/_highlevel_ssl_helpers.py @@ -8,10 +8,8 @@ from ._highlevel_open_tcp_stream import DEFAULT_DELAY from ._highlevel_socket import SocketStream -from .abc import Stream T = TypeVar("T") -T_Stream = TypeVar("T_Stream", bound=Stream) # It might have been nice to take a ssl_protocols= argument here to set up @@ -100,7 +98,7 @@ async def open_ssl_over_tcp_listeners( async def serve_ssl_over_tcp( - handler: Callable[[trio.SSLStream[T_Stream]], Awaitable[object]], + handler: Callable[[trio.SSLStream[SocketStream]], Awaitable[object]], port: int, ssl_context: ssl.SSLContext, *, @@ -109,7 +107,7 @@ async def serve_ssl_over_tcp( backlog: int | float | None = None, handler_nursery: trio.Nursery | None = None, task_status: trio.TaskStatus[ - list[trio.SSLListener[T_Stream]] + list[trio.SSLListener[SocketStream]] ] = trio.TASK_STATUS_IGNORED, ) -> NoReturn: """Listen for incoming TCP connections, and for each one start a task diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py index 7d941747e3..60520901b7 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/trio/_subprocess_platform/waitid.py @@ -10,6 +10,11 @@ assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING +if TYPE_CHECKING: + + def sync_wait_reapable(pid: int) -> None: + ... + try: from os import waitid diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index fbe2670f46..5f6b0e9892 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -67,10 +67,10 @@ def _fake_err(code: int) -> NoReturn: raise OSError(code, os.strerror(code)) -def _scatter(data: bytes, buffers: Iterable[bytes]) -> int: +def _scatter(data: bytes, buffers: Iterable[Buffer]) -> int: written = 0 for buf in buffers: - next_piece = data[written : written + len(buf)] + next_piece = data[written : written + memoryview(buf).nbytes] with memoryview(buf) as mbuf: mbuf[: len(next_piece)] = next_piece written += len(next_piece) @@ -107,6 +107,7 @@ def from_python_sockaddr( @attr.frozen class UDPBinding: local: UDPEndpoint + # remote: UDPEndpoint # ?? @attr.frozen @@ -299,6 +300,10 @@ async def sendmsg(self, *args: Any) -> int: ancdata = [] flags = 0 address = None + + # This does *not* match up with socket.socket.sendmsg (!!!) + # https://docs.python.org/3/library/socket.html#socket.socket.sendmsg + # they always have (buffers, ancdata, flags, address) if len(args) == 1: (buffers,) = args elif len(args) == 2: @@ -377,10 +382,17 @@ def getsockname(self) -> tuple[str, int] | tuple[str, int, int, int]: assert self.family == trio.socket.AF_INET6 return ("::", 0) - def getpeername(self) -> None: + # TODO: This method is not tested, and seems to make incorrect assumptions. It should maybe raise NotImplementedError. + def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]: self._check_closed() if self._binding is not None: + assert hasattr( + self._binding, "remote" + ), "This method seems to assume that self._binding has a remote UDPEndpoint" if self._binding.remote is not None: + assert isinstance( + self._binding.remote, UDPEndpoint + ), "Self._binding.remote should be a UDPEndpoint" return self._binding.remote.as_python_sockaddr() _fake_err(errno.ENOTCONN) @@ -416,9 +428,11 @@ def setsockopt( ) -> None: self._check_closed() - if (level, optname) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY): - if not value: - raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True") + if (level, optname) == ( + trio.socket.IPPROTO_IPV6, + trio.socket.IPV6_V6ONLY, + ) and not value: + raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True") raise OSError(f"FakeNet doesn't implement setsockopt({level}, {optname}, ...)") @@ -465,7 +479,7 @@ async def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, Any]: async def recvfrom_into( self, buf: Buffer, nbytes: int = 0, flags: int = 0 ) -> tuple[int, Any]: - if nbytes != 0 and nbytes != len(buf): + if nbytes != 0 and nbytes != memoryview(buf).nbytes: raise NotImplementedError("partial recvfrom_into") got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( [buf], 0, flags From af9b7752a8e686679a4f18a8affe538389aa7160 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 20 Oct 2023 17:15:50 +0200 Subject: [PATCH 21/35] enable disallow_any_unimported and disallow_untyped_calls, fix generic Outcome's, clean up ugly casts introduced by pyannotate --- pyproject.toml | 6 ++---- trio/_core/_multierror.py | 5 ++--- trio/_core/_tests/test_guest_mode.py | 13 +++++++------ trio/_core/_tests/test_thread_cache.py | 14 +++++++------- trio/_file_io.py | 2 +- trio/_highlevel_ssl_helpers.py | 4 ++-- trio/_path.py | 2 +- trio/_socket.py | 2 +- trio/_ssl.py | 2 +- trio/_tests/test_highlevel_ssl_helpers.py | 15 ++++++++++----- trio/_tests/test_threads.py | 13 ++++++------- trio/_tests/test_timeouts.py | 6 ++---- trio/_tests/test_util.py | 6 +++--- 13 files changed, 45 insertions(+), 45 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 75849ec9bd..dfeab31760 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,16 +70,14 @@ warn_return_any = true # Avoid subtle backsliding disallow_any_decorated = true disallow_any_generics = true +disallow_any_unimported = true disallow_incomplete_defs = true disallow_subclassing_any = true +disallow_untyped_calls = true disallow_untyped_decorators = true disallow_untyped_defs = true check_untyped_defs = true -# Enable once other problems are dealt with -disallow_untyped_calls = false -disallow_any_unimported = false # Enable once Outcome has stubs. - [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] faulthandler_timeout = 60 diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 05eaad33e3..79c1cd0c7c 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -424,9 +424,8 @@ def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackT else: # http://doc.pypy.org/en/latest/objspace-proxies.html def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType: - # Mypy refuses to believe that ProxyOperation can be imported properly - # TODO: will need no-any-unimported if/when that's toggled on - def controller(operation: tputil.ProxyOperation) -> Any | None: + # tputil.ProxyOperation is PyPy-only, but we run mypy on CPython + def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[no-any-unimported] # Rationale for pragma: I looked fairly carefully and tried a few # things, and AFAICT it's not actually possible to get any # 'opname' that isn't __getattr__ or __getattribute__. So there's diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index be9d4cb9c2..dd8847b448 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -20,7 +20,6 @@ Callable, NoReturn, TypeVar, - cast, ) import pytest @@ -54,7 +53,7 @@ def trivial_guest_run( in_host_after_start: Callable[[], None] | None = None, **start_guest_run_kwargs: Any, ) -> T: - todo: queue.Queue[tuple[str, Outcome]] = queue.Queue() + todo: queue.Queue[tuple[str, Outcome[T] | Callable[..., object]]] = queue.Queue() host_thread = threading.current_thread() @@ -76,7 +75,7 @@ def run_sync_soon_not_threadsafe(fn: Callable[[], object]) -> None: todo.put(("run", crash)) todo.put(("run", fn)) - def done_callback(outcome: Outcome) -> None: + def done_callback(outcome: Outcome[T]) -> None: nonlocal todo todo.put(("unwrap", outcome)) @@ -95,9 +94,11 @@ def done_callback(outcome: Outcome) -> None: while True: op, obj = todo.get() if op == "run": + assert not isinstance(obj, Outcome) obj() elif op == "unwrap": - return obj.unwrap() # type: ignore[no-any-return] + assert isinstance(obj, Outcome) + return obj.unwrap() else: # pragma: no cover assert False finally: @@ -435,7 +436,7 @@ def aiotrio_run( async def aio_main() -> T: trio_done_fut = loop.create_future() - def trio_done_callback(main_outcome: Outcome) -> None: + def trio_done_callback(main_outcome: Outcome[object]) -> None: print(f"trio_fn finished: {main_outcome!r}") trio_done_fut.set_result(main_outcome) @@ -462,7 +463,7 @@ async def trio_main() -> str: print("trio_main!") to_trio, from_aio = trio.open_memory_channel[int](float("inf")) - from_trio = cast("asyncio.Queue[int]", asyncio.Queue()) + from_trio: asyncio.Queue[int] = asyncio.Queue() aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio)) diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py index e5b4902d8d..77fdf46664 100644 --- a/trio/_core/_tests/test_thread_cache.py +++ b/trio/_core/_tests/test_thread_cache.py @@ -4,7 +4,7 @@ import time from contextlib import contextmanager from queue import Queue -from typing import Iterator, NoReturn, cast +from typing import Iterator, NoReturn import pytest from outcome import Outcome @@ -16,12 +16,12 @@ def test_thread_cache_basics() -> None: - q = cast("Queue[Outcome]", Queue()) + q: Queue[Outcome[object]] = Queue() def fn() -> NoReturn: raise RuntimeError("hi") - def deliver(outcome: Outcome) -> None: + def deliver(outcome: Outcome[object]) -> None: q.put(outcome) start_thread_soon(fn, deliver) @@ -41,9 +41,9 @@ def __call__(self) -> int: def __del__(self) -> None: res[0] = True - q = cast("Queue[Outcome]", Queue()) + q: Queue[Outcome[int]] = Queue() - def deliver(outcome: Outcome) -> None: + def deliver(outcome: Outcome[int]) -> None: q.put(outcome) start_thread_soon(del_me(), deliver) @@ -64,7 +64,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None: # Make sure there are a few threads running, so if we weren't LIFO then we # could grab the wrong one. - q = cast("Queue[Outcome]", Queue()) + q: Queue[Outcome[object]] = Queue() COUNT = 5 for _ in range(COUNT): start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result)) @@ -96,7 +96,7 @@ def test_idle_threads_exit(monkeypatch: MonkeyPatch) -> None: # CPU.) monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001) - q = cast("Queue[threading.Thread]", Queue()) + q: Queue[threading.Thread] = Queue() start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread())) seen_thread = q.get() # Since the idle timeout is 0, after sleeping for 1 second, the thread diff --git a/trio/_file_io.py b/trio/_file_io.py index 6f30f89dd8..4a5650b453 100644 --- a/trio/_file_io.py +++ b/trio/_file_io.py @@ -430,7 +430,7 @@ async def open_file( @overload -async def open_file( # type: ignore[misc] # Any usage matches builtins.open(). +async def open_file( file: _OpenFile, mode: str, buffering: int = -1, diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py index 0006d7a41f..321119216f 100644 --- a/trio/_highlevel_ssl_helpers.py +++ b/trio/_highlevel_ssl_helpers.py @@ -22,7 +22,7 @@ # So... let's punt on that for now. Hopefully we'll be getting a new Python # TLS API soon and can revisit this then. async def open_ssl_over_tcp_stream( - host: str | bytes, + host: str, port: int, *, https_compatible: bool = False, @@ -40,7 +40,7 @@ async def open_ssl_over_tcp_stream( data. Args: - host (bytes or str): The host to connect to. We require the server + host (str): The host to connect to. We require the server to have a TLS certificate valid for this hostname. port (int): The port to connect to. https_compatible (bool): Set this to True if you're connecting to a web diff --git a/trio/_path.py b/trio/_path.py index 219f706825..fb01420a75 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -311,7 +311,7 @@ async def open( ... @overload - async def open( # type: ignore[misc] # Any usage matches builtins.open(). + async def open( self, mode: str, buffering: int = -1, diff --git a/trio/_socket.py b/trio/_socket.py index 7f6d9b3581..bce0fcf9b3 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -1188,7 +1188,7 @@ async def sendto(self, *args: Any) -> int: # We don't care about invalid types, sendto() will do the checking. return await self._nonblocking_helper( _core.wait_writable, - _stdlib_socket.socket.sendto, # type: ignore[arg-type] + _stdlib_socket.socket.sendto, *args_list, ) diff --git a/trio/_ssl.py b/trio/_ssl.py index 78123c9e68..9a9b608f9e 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -345,7 +345,7 @@ def __init__( transport_stream: T_Stream, ssl_context: _stdlib_ssl.SSLContext, *, - server_hostname: str | bytes | None = None, + server_hostname: str | None = None, server_side: bool = False, https_compatible: bool = False, ) -> None: diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py index c1b0febbd5..8e90adb3d2 100644 --- a/trio/_tests/test_highlevel_ssl_helpers.py +++ b/trio/_tests/test_highlevel_ssl_helpers.py @@ -72,15 +72,20 @@ async def test_open_ssl_over_tcp_stream_and_everything_else( client_ctx: SSLContext, # noqa: F811 # linters doesn't understand fixture ) -> None: async with trio.open_nursery() as nursery: - # TODO: the types are *very* funky here, this seems like an error in some signature - # unless this is doing stuff we don't want/expect end users to do - res: list[SSLListener[SocketListener]] = await nursery.start( - partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1") + # TODO: this function wraps an SSLListener around a SocketListener, this is illegal + # according to current type hints, and probably for good reason. But there should + # maybe be a different wrapper class/function that could be used instead? + res: list[SSLListener[SocketListener]] = ( # type: ignore[type-var] + await nursery.start( + partial( + serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1" + ) + ) ) (listener,) = res async with listener: # listener.transport_listener is of type Listener[Stream] - tp_listener: SocketListener = listener.transport_listener + tp_listener: SocketListener = listener.transport_listener # type: ignore[assignment] sockaddr = tp_listener.socket.getsockname() hostname_resolver = FakeHostnameResolver(sockaddr) diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 77650bf569..327c35a4d9 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -19,7 +19,6 @@ Type, TypeVar, Union, - cast, ) import pytest @@ -346,7 +345,7 @@ async def child(q: stdlib_queue.Queue[None], cancellable: bool) -> None: record.append("exit") record: list[str] = [] - q = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue()) + q: stdlib_queue.Queue[None] = stdlib_queue.Queue() async with _core.open_nursery() as nursery: nursery.start_soon(child, q, True) # Give it a chance to get started. (This is important because @@ -394,8 +393,8 @@ def test_run_in_worker_thread_abandoned( ) -> None: monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) - q1 = cast("stdlib_queue.Queue[None]", stdlib_queue.Queue()) - q2 = cast("stdlib_queue.Queue[threading.Thread]", stdlib_queue.Queue()) + q1: stdlib_queue.Queue[None] = stdlib_queue.Queue() + q2: stdlib_queue.Queue[threading.Thread] = stdlib_queue.Queue() def thread_fn() -> None: q1.get() @@ -916,7 +915,7 @@ def get_tid_then_reenter() -> int: async def test_from_thread_host_cancelled() -> None: - queue = cast("stdlib_queue.Queue[bool]", stdlib_queue.Queue()) + queue: stdlib_queue.Queue[bool] = stdlib_queue.Queue() def sync_check() -> None: from_thread_run_sync(cancel_scope.cancel) @@ -975,7 +974,7 @@ async def async_time_bomb() -> None: async def test_from_thread_check_cancelled() -> None: - q = cast("stdlib_queue.Queue[str]", stdlib_queue.Queue()) + q: stdlib_queue.Queue[str] = stdlib_queue.Queue() async def child(cancellable: bool, scope: CancelScope) -> None: with scope: @@ -1055,7 +1054,7 @@ def f() -> None: # type: ignore[no-redef] # noqa: F811 async def test_from_thread_check_cancelled_raises_in_foreign_threads() -> None: with pytest.raises(RuntimeError): from_thread_check_cancelled() - q = cast("stdlib_queue.Queue[Outcome]", stdlib_queue.Queue()) + q: stdlib_queue.Queue[Outcome[object]] = stdlib_queue.Queue() _core.start_thread_soon(from_thread_check_cancelled, lambda _: q.put(_)) with pytest.raises(RuntimeError): q.get(timeout=1).unwrap() diff --git a/trio/_tests/test_timeouts.py b/trio/_tests/test_timeouts.py index 918e763faa..43fadfdaca 100644 --- a/trio/_tests/test_timeouts.py +++ b/trio/_tests/test_timeouts.py @@ -12,9 +12,7 @@ T = TypeVar("T") -async def check_takes_about( - f: Callable[[], Awaitable[T]], expected_dur: float -) -> Awaitable[T]: +async def check_takes_about(f: Callable[[], Awaitable[T]], expected_dur: float) -> T: start = time.perf_counter() result = await outcome.acapture(f) dur = time.perf_counter() - start @@ -41,7 +39,7 @@ async def check_takes_about( assert (1 - 1e-8) <= (dur / expected_dur) < 1.5 # outcome is not typed - return result.unwrap() # type: ignore[no-any-return] + return result.unwrap() # How long to (attempt to) sleep for when testing. Smaller numbers make the diff --git a/trio/_tests/test_util.py b/trio/_tests/test_util.py index 01a7503e9d..40c2fd11bb 100644 --- a/trio/_tests/test_util.py +++ b/trio/_tests/test_util.py @@ -270,7 +270,7 @@ def test_fixup_module_metadata() -> None: assert mod.SomeClass.method.__module__ == "trio.somemodule" # type: ignore[attr-defined] assert mod.SomeClass.method.__qualname__ == "SomeClass.method" # type: ignore[attr-defined] # Make coverage happy. - non_trio_module.some_func() - mod.some_func() - mod._private() + non_trio_module.some_func() # type: ignore[no-untyped-call] + mod.some_func() # type: ignore[no-untyped-call] + mod._private() # type: ignore[no-untyped-call] mod.SomeClass().method() From 077092b694020570e54ffd84f3bbe06b0c615342 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 20 Oct 2023 17:49:12 +0200 Subject: [PATCH 22/35] readd type: ignore's incorrectly removed, fix type-checking on non-linux --- trio/_file_io.py | 2 +- trio/_path.py | 2 +- trio/_socket.py | 2 +- trio/_subprocess_platform/waitid.py | 4 ++-- trio/_tests/test_exports.py | 3 ++- trio/_tests/test_highlevel_open_unix_stream.py | 4 ++++ trio/_tests/test_subprocess.py | 12 +++++++++--- 7 files changed, 20 insertions(+), 9 deletions(-) diff --git a/trio/_file_io.py b/trio/_file_io.py index 4a5650b453..6f30f89dd8 100644 --- a/trio/_file_io.py +++ b/trio/_file_io.py @@ -430,7 +430,7 @@ async def open_file( @overload -async def open_file( +async def open_file( # type: ignore[misc] # Any usage matches builtins.open(). file: _OpenFile, mode: str, buffering: int = -1, diff --git a/trio/_path.py b/trio/_path.py index fb01420a75..219f706825 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -311,7 +311,7 @@ async def open( ... @overload - async def open( + async def open( # type: ignore[misc] # Any usage matches builtins.open(). self, mode: str, buffering: int = -1, diff --git a/trio/_socket.py b/trio/_socket.py index bce0fcf9b3..7f6d9b3581 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -1188,7 +1188,7 @@ async def sendto(self, *args: Any) -> int: # We don't care about invalid types, sendto() will do the checking. return await self._nonblocking_helper( _core.wait_writable, - _stdlib_socket.socket.sendto, + _stdlib_socket.socket.sendto, # type: ignore[arg-type] *args_list, ) diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py index 60520901b7..03a464d6a9 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/trio/_subprocess_platform/waitid.py @@ -8,14 +8,14 @@ from .._sync import CapacityLimiter, Event from .._threads import to_thread_run_sync -assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING - if TYPE_CHECKING: def sync_wait_reapable(pid: int) -> None: ... +assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING + try: from os import waitid diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py index 47671b74a3..74a87f7115 100644 --- a/trio/_tests/test_exports.py +++ b/trio/_tests/test_exports.py @@ -157,8 +157,9 @@ def no_underscores(symbols: Iterable[str]) -> set[str]: skip_if_optional_else_raise(error) linter = PyLinter() + assert module.__file__ is not None ast = linter.get_ast(module.__file__, modname) - static_names = no_underscores(ast) + static_names = no_underscores(ast) # type: ignore[arg-type] elif tool == "jedi": try: import jedi diff --git a/trio/_tests/test_highlevel_open_unix_stream.py b/trio/_tests/test_highlevel_open_unix_stream.py index 045820fccf..0ff11209a7 100644 --- a/trio/_tests/test_highlevel_open_unix_stream.py +++ b/trio/_tests/test_highlevel_open_unix_stream.py @@ -1,12 +1,16 @@ import os import socket +import sys import tempfile +from typing import TYPE_CHECKING import pytest from trio import Path, open_unix_socket from trio._highlevel_open_unix_stream import close_on_error +assert not TYPE_CHECKING or sys.platform != "win32" + if not hasattr(socket, "AF_UNIX"): pytestmark = pytest.mark.skip("Needs unix socket support") diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index 8c32f1c49b..51389da518 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -79,7 +79,7 @@ def SLEEP(seconds: int) -> list[str]: def got_signal(proc: Process, sig: SignalType) -> bool: - if posix: + if (not TYPE_CHECKING and posix) or sys.platform != "win32": return proc.returncode == -sig else: return proc.returncode != 0 @@ -444,7 +444,8 @@ async def test_stderr_stdout(background_process: BackgroundProcessType) -> None: async def test_errors() -> None: with pytest.raises(TypeError) as excinfo: - await open_process(["ls"], encoding="utf-8") # type: ignore[call-overload] + # call-overload on unix, call-arg on windows + await open_process(["ls"], encoding="utf-8") # type: ignore assert "unbuffered byte streams" in str(excinfo.value) assert "the 'encoding' option is not supported" in str(excinfo.value) @@ -480,13 +481,15 @@ async def test_one_signal( # tries to handle SIGINT during startup. SIGUSR1's default disposition is # to terminate the target process, and Python doesn't try to do anything # clever to handle it. - if posix: + if (not TYPE_CHECKING and posix) or sys.platform != "win32": await test_one_signal(lambda proc: proc.send_signal(SIGUSR1), SIGUSR1) @pytest.mark.skipif(not posix, reason="POSIX specific") @background_process_param async def test_wait_reapable_fails(background_process: BackgroundProcessType) -> None: + if TYPE_CHECKING and sys.platform == "win32": + return old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN) try: # With SIGCHLD disabled, the wait() syscall will wait for the @@ -510,6 +513,9 @@ def test_waitid_eintr() -> None: # ourselves) but the test works on all waitid platforms. from .._subprocess_platform import wait_child_exiting + if TYPE_CHECKING and sys.platform == "win32": + return + if not wait_child_exiting.__module__.endswith("waitid"): pytest.skip("waitid only") from .._subprocess_platform.waitid import sync_wait_reapable From fcdbdbc410189f6f00e729bef45df0b397b5b484 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 23 Oct 2023 16:27:20 +0200 Subject: [PATCH 23/35] fix ruff/mypy issues --- trio/_tests/test_file_io.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trio/_tests/test_file_io.py b/trio/_tests/test_file_io.py index efcfedeb09..85b8324c56 100644 --- a/trio/_tests/test_file_io.py +++ b/trio/_tests/test_file_io.py @@ -83,7 +83,8 @@ def unsupported_attr(self) -> None: # pragma: no cover assert hasattr(async_file.wrapped, "unsupported_attr") with pytest.raises(AttributeError): - async_file.unsupported_attr # noqa: B018 # "useless expression" + # B018 "useless expression" + async_file.unsupported_attr # type: ignore[attr-defined] # noqa: B018 def test_type_stubs_match_lists() -> None: @@ -247,7 +248,8 @@ async def test_aclose_cancelled(path: pathlib.Path) -> None: async def test_detach_rewraps_asynciobase(tmp_path: pathlib.Path) -> None: tmp_file = tmp_path / "filename" tmp_file.touch() - with open(tmp_file, mode="rb", buffering=0) as raw: + # flake8-async does not like opening files in async mode + with open(tmp_file, mode="rb", buffering=0) as raw: # noqa: ASYNC101 buffered = io.BufferedReader(raw) async_file = trio.wrap_file(buffered) From 16631e99d10b552a7541c1f687f37d12534867bc Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 23 Oct 2023 16:35:40 +0200 Subject: [PATCH 24/35] fix B904 error --- trio/_threads.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trio/_threads.py b/trio/_threads.py index 3a8d83a55a..d44e4b68f8 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -409,10 +409,10 @@ def from_thread_check_cancelled() -> None: """ try: raise_cancel = PARENT_TASK_DATA.cancel_register[0] - except AttributeError: + except AttributeError as exc: raise RuntimeError( "this thread wasn't created by Trio, can't check for cancellation" - ) + ) from exc if raise_cancel is not None: raise_cancel() From 208b1a3f14f91b5bc4e5ce8f991c61725042cc50 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 25 Oct 2023 13:51:26 +0200 Subject: [PATCH 25/35] fix after merge + coolcat review --- trio/_tests/test_subprocess.py | 3 ++- trio/_tests/test_threads.py | 2 +- trio/_tests/test_timeouts.py | 1 - 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index 51389da518..7a5fafdddb 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -9,6 +9,7 @@ from functools import partial from pathlib import Path as SyncPath from signal import Signals +from types import FrameType from typing import ( TYPE_CHECKING, Any, @@ -523,7 +524,7 @@ def test_waitid_eintr() -> None: got_alarm = False sleeper = subprocess.Popen(["sleep", "3600"]) - def on_alarm(sig: object, frame: object) -> None: + def on_alarm(sig: int, frame: FrameType | None) -> None: nonlocal got_alarm got_alarm = True sleeper.kill() diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 90755501bc..637a035d63 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -1062,7 +1062,7 @@ async def test_from_thread_check_cancelled_raises_in_foreign_threads() -> None: @slow -async def test_reentry_doesnt_deadlock(): +async def test_reentry_doesnt_deadlock() -> None: # Regression test for issue noticed in GH-2827 # The failure mode is to hang the whole test suite, unfortunately. # XXX consider running this in a subprocess with a timeout, if it comes up again! diff --git a/trio/_tests/test_timeouts.py b/trio/_tests/test_timeouts.py index 43fadfdaca..c6def0bf9e 100644 --- a/trio/_tests/test_timeouts.py +++ b/trio/_tests/test_timeouts.py @@ -38,7 +38,6 @@ async def check_takes_about(f: Callable[[], Awaitable[T]], expected_dur: float) # started as a 1 ULP error at a different dynamic range.) assert (1 - 1e-8) <= (dur / expected_dur) < 1.5 - # outcome is not typed return result.unwrap() From cfbbcea902fd8cd190f4dbb0b03d54a1f8bfb0c8 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 25 Oct 2023 15:21:52 +0200 Subject: [PATCH 26/35] add --show-fixes to ruff in pre-commit, test show_diff_on_failure --- .pre-commit-config.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ca4811ed2c..deda376520 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,6 +19,7 @@ repos: - id: ruff types: [file] types_or: [python, pyi, toml] + args: ["--show-fixes"] - repo: https://github.com/codespell-project/codespell rev: v2.2.6 hooks: @@ -30,3 +31,4 @@ ci: autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" autoupdate_schedule: weekly submodules: false + show_diff_on_failure: true From 617038449537e7ed6bfd0f377ee6b360fd49a910 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 25 Oct 2023 15:36:08 +0200 Subject: [PATCH 27/35] autofix from pre-commit --- .pre-commit-config.yaml | 1 - trio/_core/_tests/test_multierror.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index deda376520..cfa4347f63 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,4 +31,3 @@ ci: autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" autoupdate_schedule: weekly submodules: false - show_diff_on_failure: true diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py index 589fd0eea6..83a0489653 100644 --- a/trio/_core/_tests/test_multierror.py +++ b/trio/_core/_tests/test_multierror.py @@ -10,7 +10,7 @@ from pathlib import Path from traceback import extract_tb, print_exception from types import TracebackType -from typing import Callable, List, NoReturn +from typing import Callable, NoReturn import pytest @@ -400,7 +400,7 @@ def simple_filter(exc: BaseException) -> Exception | RuntimeError: gc.garbage.clear() -def assert_match_in_seq(pattern_list: List[str], string: str) -> None: +def assert_match_in_seq(pattern_list: list[str], string: str) -> None: offset = 0 print("looking for pattern matches...") for pattern in pattern_list: From 90d6cab092cbfd3f456fde373abbc623d4733a8c Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 25 Oct 2023 15:57:28 +0200 Subject: [PATCH 28/35] fix RTD build error --- docs/source/conf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index 06b661f126..c56ce12925 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -71,6 +71,8 @@ # "types.FrameType" is more helpful than just "frame" "FrameType": "types.FrameType", "Context": "OpenSSL.SSL.Context", + # SSLListener.accept's return type is seen as trio._ssl.SSLStream + "SSLStream": "trio.SSLStream", } From 95fa49e762773f82ac422fd34534a4621d591623 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 25 Oct 2023 16:05:31 +0200 Subject: [PATCH 29/35] fix most codecov issues --- .coveragerc | 2 +- trio/_core/_tests/test_guest_mode.py | 2 +- trio/_core/_tests/test_multierror.py | 2 +- trio/_highlevel_serve_listeners.py | 2 +- trio/_tests/test_ssl.py | 6 +++--- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.coveragerc b/.coveragerc index 4911012653..604c14775f 100644 --- a/.coveragerc +++ b/.coveragerc @@ -20,7 +20,7 @@ skip_covered = True exclude_lines = pragma: no cover abc.abstractmethod - if TYPE_CHECKING: + if TYPE_CHECKING.*: if _t.TYPE_CHECKING: if t.TYPE_CHECKING: @overload diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index af80fd8218..175f687125 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -483,7 +483,7 @@ async def trio_main() -> str: aio_task.cancel() return "trio-main-done" - raise AssertionError("should never be reached") + raise AssertionError("should never be reached") # pragma: no cov async def aio_pingpong( from_trio: asyncio.Queue[int], to_trio: MemorySendChannel[int] diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py index 83a0489653..91e5fc73b6 100644 --- a/trio/_core/_tests/test_multierror.py +++ b/trio/_core/_tests/test_multierror.py @@ -69,7 +69,7 @@ def get_exc(raiser: Callable[[], NoReturn]) -> BaseException: raiser() except Exception as exc: return exc - raise AssertionError("raiser should always raise") + raise AssertionError("raiser should always raise") # pragma: no cov def get_tb(raiser: Callable[[], NoReturn]) -> TracebackType | None: diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py index d949cb8f87..0940e102a5 100644 --- a/trio/_highlevel_serve_listeners.py +++ b/trio/_highlevel_serve_listeners.py @@ -142,4 +142,4 @@ async def serve_listeners( # error or whatever. task_status.started(listeners) - raise AssertionError("_serve_one_listener should never complete") + raise AssertionError("_serve_one_listener should never complete") # pragma: no cov diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index e0df547b39..8e904679e9 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -778,13 +778,13 @@ async def wait_send_all_might_not_block(self) -> None: # define methods that are abstract in Stream async def aclose(self) -> None: - raise AssertionError("Should not get called") + raise AssertionError("Should not get called") # pragma: no cov async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: - raise AssertionError("Should not get called") + raise AssertionError("Should not get called") # pragma: no cov async def send_all(self, data: bytes | bytearray | memoryview) -> None: - raise AssertionError("Should not get called") + raise AssertionError("Should not get called") # pragma: no cov ctx = ssl.create_default_context() s = SSLStream(NotAStream(), ctx, server_hostname="x") From 027279a806ec248cdbe44eccb269f68782477d2d Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 25 Oct 2023 16:41:37 +0200 Subject: [PATCH 30/35] pragma: no cov -> pragma: no cover --- trio/_core/_tests/test_guest_mode.py | 2 +- trio/_core/_tests/test_multierror.py | 2 +- trio/_highlevel_serve_listeners.py | 4 +++- trio/_tests/test_ssl.py | 6 +++--- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index 175f687125..a648832b4c 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -483,7 +483,7 @@ async def trio_main() -> str: aio_task.cancel() return "trio-main-done" - raise AssertionError("should never be reached") # pragma: no cov + raise AssertionError("should never be reached") # pragma: no cover async def aio_pingpong( from_trio: asyncio.Queue[int], to_trio: MemorySendChannel[int] diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py index 91e5fc73b6..09f9b6a271 100644 --- a/trio/_core/_tests/test_multierror.py +++ b/trio/_core/_tests/test_multierror.py @@ -69,7 +69,7 @@ def get_exc(raiser: Callable[[], NoReturn]) -> BaseException: raiser() except Exception as exc: return exc - raise AssertionError("raiser should always raise") # pragma: no cov + raise AssertionError("raiser should always raise") # pragma: no cover def get_tb(raiser: Callable[[], NoReturn]) -> TracebackType | None: diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py index 0940e102a5..ec5a0efb3c 100644 --- a/trio/_highlevel_serve_listeners.py +++ b/trio/_highlevel_serve_listeners.py @@ -142,4 +142,6 @@ async def serve_listeners( # error or whatever. task_status.started(listeners) - raise AssertionError("_serve_one_listener should never complete") # pragma: no cov + raise AssertionError( + "_serve_one_listener should never complete" + ) # pragma: no cover diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index 8e904679e9..cd404a1c97 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -778,13 +778,13 @@ async def wait_send_all_might_not_block(self) -> None: # define methods that are abstract in Stream async def aclose(self) -> None: - raise AssertionError("Should not get called") # pragma: no cov + raise AssertionError("Should not get called") # pragma: no cover async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: - raise AssertionError("Should not get called") # pragma: no cov + raise AssertionError("Should not get called") # pragma: no cover async def send_all(self, data: bytes | bytearray | memoryview) -> None: - raise AssertionError("Should not get called") # pragma: no cov + raise AssertionError("Should not get called") # pragma: no cover ctx = ssl.create_default_context() s = SSLStream(NotAStream(), ctx, server_hostname="x") From fc0a6d3b5bd360ac907dcb2bc8254d73b98a3350 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 29 Oct 2023 18:54:06 +0000 Subject: [PATCH 31/35] [pre-commit.ci] auto fixes from pre-commit.com hooks --- trio/_tests/test_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_tests/test_sync.py b/trio/_tests/test_sync.py index 3f9ca88650..9179c8a5ae 100644 --- a/trio/_tests/test_sync.py +++ b/trio/_tests/test_sync.py @@ -247,7 +247,7 @@ async def test_Semaphore_bounded() -> None: @pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__) async def test_Lock_and_StrictFIFOLock( - lockcls: type[Lock] | type[StrictFIFOLock], + lockcls: type[Lock | StrictFIFOLock], ) -> None: l = lockcls() # noqa assert not l.locked() From 81f8bf0b8f0aa2c43844579a6895fb9759c8c27a Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 30 Oct 2023 12:31:21 +0100 Subject: [PATCH 32/35] minor changes after review comments --- trio/_core/_tests/test_instrumentation.py | 5 ++++- trio/_core/_tests/test_ki.py | 1 + trio/_core/_tests/tutil.py | 6 ++++-- trio/_highlevel_ssl_helpers.py | 4 ++-- trio/_ssl.py | 2 +- trio/_subprocess_platform/waitid.py | 6 ------ trio/_tests/test_exports.py | 3 +-- trio/_tests/test_highlevel_serve_listeners.py | 17 ++++++++--------- trio/_tests/test_subprocess.py | 10 ++++++++-- 9 files changed, 29 insertions(+), 25 deletions(-) diff --git a/trio/_core/_tests/test_instrumentation.py b/trio/_core/_tests/test_instrumentation.py index 1c35a17ee6..f743f2b3d4 100644 --- a/trio/_core/_tests/test_instrumentation.py +++ b/trio/_core/_tests/test_instrumentation.py @@ -219,7 +219,10 @@ async def main() -> None: # Changing the set of hooks implemented by an instrument after # it's installed doesn't make them start being called right away - instrument.before_task_step = record.append # type: ignore[assignment] + instrument.before_task_step = ( # type: ignore[method-assign] + record.append # type: ignore[assignment] # append is pos-only + ) + await _core.checkpoint() await _core.checkpoint() assert len(record) == 0 diff --git a/trio/_core/_tests/test_ki.py b/trio/_core/_tests/test_ki.py index bc4a0192af..cd98bc9bca 100644 --- a/trio/_core/_tests/test_ki.py +++ b/trio/_core/_tests/test_ki.py @@ -145,6 +145,7 @@ def protected_manager() -> Iterator[None]: raise KeyError +# the async_generator package isn't typed, hence all the type: ignores @pytest.mark.skipif(async_generator is None, reason="async_generator not installed") async def test_async_generator_agen_protection() -> None: @_core.enable_ki_protection diff --git a/trio/_core/_tests/tutil.py b/trio/_core/_tests/tutil.py index 1d49d8b262..6ed9b5fe14 100644 --- a/trio/_core/_tests/tutil.py +++ b/trio/_core/_tests/tutil.py @@ -98,8 +98,10 @@ def restore_unraisablehook() -> Generator[None, None, None]: sys.unraisablehook = prev -# template is like: -# [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3] +# Used to check sequences that might have some elements out of order. +# Example usage: +# The sequences [1, 2.1, 2.2, 3] and [1, 2.2, 2.1, 3] are both +# matched by the template [1, {2.1, 2.2}, 3] def check_sequence_matches(seq: Sequence[T], template: Iterable[T | set[T]]) -> None: i = 0 for pattern in template: diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py index e70a8db47d..3215fcf969 100644 --- a/trio/_highlevel_ssl_helpers.py +++ b/trio/_highlevel_ssl_helpers.py @@ -22,7 +22,7 @@ # So... let's punt on that for now. Hopefully we'll be getting a new Python # TLS API soon and can revisit this then. async def open_ssl_over_tcp_stream( - host: str, + host: str | bytes, port: int, *, https_compatible: bool = False, @@ -40,7 +40,7 @@ async def open_ssl_over_tcp_stream( data. Args: - host (str): The host to connect to. We require the server + host (bytes or str): The host to connect to. We require the server to have a TLS certificate valid for this hostname. port (int): The port to connect to. https_compatible (bool): Set this to True if you're connecting to a web diff --git a/trio/_ssl.py b/trio/_ssl.py index 4ef8a15721..d2e1c73da5 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -345,7 +345,7 @@ def __init__( transport_stream: T_Stream, ssl_context: _stdlib_ssl.SSLContext, *, - server_hostname: str | None = None, + server_hostname: str | bytes | None = None, server_side: bool = False, https_compatible: bool = False, ) -> None: diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py index 03a464d6a9..756741218f 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/trio/_subprocess_platform/waitid.py @@ -8,12 +8,6 @@ from .._sync import CapacityLimiter, Event from .._threads import to_thread_run_sync -if TYPE_CHECKING: - - def sync_wait_reapable(pid: int) -> None: - ... - - assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING try: diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py index ec10edf4eb..7b38137887 100644 --- a/trio/_tests/test_exports.py +++ b/trio/_tests/test_exports.py @@ -2,6 +2,7 @@ import __future__ # Regular import, not special! +import enum import functools import importlib import inspect @@ -446,8 +447,6 @@ def lookup_symbol(symbol: str) -> dict[str, str]: extra = {e for e in extra if not e.endswith("AttrsAttributes__")} assert len(extra) == before - 1 - import enum - # mypy does not see these attributes in Enum subclasses if ( tool == "mypy" diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py index 060ba46ed7..75cadfd8aa 100644 --- a/trio/_tests/test_highlevel_serve_listeners.py +++ b/trio/_tests/test_highlevel_serve_listeners.py @@ -21,25 +21,26 @@ # types are somewhat tentative - I just bruteforced them until I got something that didn't # give errors -TypeThing = StapledStream[MemorySendStream, MemoryReceiveStream] +StapledMemoryStream = StapledStream[MemorySendStream, MemoryReceiveStream] @attr.s(hash=False, eq=False) -class MemoryListener(trio.abc.Listener[TypeThing]): +class MemoryListener(trio.abc.Listener[StapledMemoryStream]): closed: bool = attr.ib(default=False) accepted_streams: list[trio.abc.Stream] = attr.ib(factory=list) queued_streams: tuple[ - MemorySendChannel[TypeThing], MemoryReceiveChannel[TypeThing] - ] = attr.ib(factory=(lambda: trio.open_memory_channel[TypeThing](1))) + MemorySendChannel[StapledMemoryStream], + MemoryReceiveChannel[StapledMemoryStream], + ] = attr.ib(factory=(lambda: trio.open_memory_channel[StapledMemoryStream](1))) accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None) - async def connect(self) -> StapledStream[MemorySendStream, MemoryReceiveStream]: + async def connect(self) -> StapledMemoryStream: assert not self.closed client, server = memory_stream_pair() await self.queued_streams[0].send(server) return client - async def accept(self) -> TypeThing: + async def accept(self) -> StapledMemoryStream: await trio.lowlevel.checkpoint() assert not self.closed if self.accept_hook is not None: @@ -63,9 +64,7 @@ def close_hook() -> None: assert trio.current_effective_deadline() == float("-inf") record.append("closed") - async def handler( - stream: StapledStream[MemorySendStream, MemoryReceiveStream] - ) -> None: + async def handler(stream: StapledMemoryStream) -> None: await stream.send_all(b"123") assert await stream.receive_some(10) == b"456" stream.send_stream.close_hook = close_hook diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index 7a5fafdddb..eb8bfd1c43 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -514,12 +514,18 @@ def test_waitid_eintr() -> None: # ourselves) but the test works on all waitid platforms. from .._subprocess_platform import wait_child_exiting - if TYPE_CHECKING and sys.platform == "win32": + if TYPE_CHECKING and (sys.platform == "win32" or sys.platform == "darwin"): return if not wait_child_exiting.__module__.endswith("waitid"): pytest.skip("waitid only") - from .._subprocess_platform.waitid import sync_wait_reapable + + # despite the TYPE_CHECKING early return silencing warnings about signal.SIGALRM etc + # this import is still checked on win32&darwin and raises [attr-defined]. + # Linux doesn't raise [attr-defined] though, so we need [unused-ignore] + from .._subprocess_platform.waitid import ( + sync_wait_reapable, # type: ignore[attr-defined, unused-ignore] + ) got_alarm = False sleeper = subprocess.Popen(["sleep", "3600"]) From 3935a8269763cad257ba29a0759974be53f2c7ee Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 30 Oct 2023 12:38:50 +0100 Subject: [PATCH 33/35] fix incorrect suppression found by enabling typechecking (!), move type: ignore to the correct line after incorrect autofix by black --- trio/_tests/test_ssl.py | 6 ++---- trio/_tests/test_subprocess.py | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index cd404a1c97..e0b76d812a 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -5,7 +5,7 @@ import ssl import sys import threading -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager, contextmanager, suppress from functools import partial from ssl import SSLContext from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, NoReturn @@ -116,10 +116,8 @@ def ssl_echo_serve_sync( # respond in kind but it's legal for them to have already # gone away. exceptions = (BrokenPipeError, ssl.SSLZeroReturnError) - try: + with suppress(*exceptions): wrapped.unwrap() - except exceptions: - pass return wrapped.sendall(data) # This is an obscure workaround for an openssl bug. In server mode, in diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index eb8bfd1c43..c901f6f29e 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -231,7 +231,7 @@ async def check_output(stream: Stream, expected: bytes) -> None: nursery.start_soon(check_output, proc.stderr, msg[::-1]) assert not nursery.cancel_scope.cancelled_caught - assert 0 == await proc.wait() + assert await proc.wait() == 0 @background_process_param @@ -523,8 +523,8 @@ def test_waitid_eintr() -> None: # despite the TYPE_CHECKING early return silencing warnings about signal.SIGALRM etc # this import is still checked on win32&darwin and raises [attr-defined]. # Linux doesn't raise [attr-defined] though, so we need [unused-ignore] - from .._subprocess_platform.waitid import ( - sync_wait_reapable, # type: ignore[attr-defined, unused-ignore] + from .._subprocess_platform.waitid import ( # type: ignore[attr-defined, unused-ignore] + sync_wait_reapable, ) got_alarm = False From 4d25e75f5a7f393e0e75741257381d7336fdf496 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 30 Oct 2023 13:08:33 +0100 Subject: [PATCH 34/35] mark lines with # pragma: no cover --- trio/_tests/test_tracing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trio/_tests/test_tracing.py b/trio/_tests/test_tracing.py index 1b9ea02b5f..5cf758c6b6 100644 --- a/trio/_tests/test_tracing.py +++ b/trio/_tests/test_tracing.py @@ -22,9 +22,9 @@ async def coro2_async_gen(event: trio.Event) -> AsyncGenerator[None, None]: await trio.lowlevel.checkpoint() yield await coro1(event) - yield - await trio.lowlevel.checkpoint() - yield + yield # pragma: no cover + await trio.lowlevel.checkpoint() # pragma: no cover + yield # pragma: no cover async def coro3_async_gen(event: trio.Event) -> None: From 43c7d6da9f56e9de33ce33bf463ec8bba4be799c Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 30 Oct 2023 13:14:24 +0100 Subject: [PATCH 35/35] remove redundant tuple --- trio/_tests/test_ssl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/trio/_tests/test_ssl.py b/trio/_tests/test_ssl.py index e0b76d812a..13decd5c72 100644 --- a/trio/_tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -115,8 +115,7 @@ def ssl_echo_serve_sync( # other side has initiated a graceful shutdown; we try to # respond in kind but it's legal for them to have already # gone away. - exceptions = (BrokenPipeError, ssl.SSLZeroReturnError) - with suppress(*exceptions): + with suppress(BrokenPipeError, ssl.SSLZeroReturnError): wrapped.unwrap() return wrapped.sendall(data)