diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index db78bbd0d35..1e9dd5031c3 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -7,10 +7,17 @@ import warnings from ast import PyCF_ONLY_AST as _AST_FLAG from bisect import bisect_right +from types import FrameType from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union import py +from _pytest.compat import overload + class Source: """ an immutable object holding a source code fragment, @@ -19,7 +26,7 @@ class Source: _compilecounter = 0 - def __init__(self, *parts, **kwargs): + def __init__(self, *parts, **kwargs) -> None: self.lines = lines = [] # type: List[str] de = kwargs.get("deindent", True) for part in parts: @@ -48,7 +55,15 @@ def __eq__(self, other): # Ignore type because of https://github.com/python/mypy/issues/4266. __hash__ = None # type: ignore - def __getitem__(self, key): + @overload + def __getitem__(self, key: int) -> str: + raise NotImplementedError() + + @overload # noqa: F811 + def __getitem__(self, key: slice) -> "Source": + raise NotImplementedError() + + def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811 if isinstance(key, int): return self.lines[key] else: @@ -58,10 +73,10 @@ def __getitem__(self, key): newsource.lines = self.lines[key.start : key.stop] return newsource - def __len__(self): + def __len__(self) -> int: return len(self.lines) - def strip(self): + def strip(self) -> "Source": """ return new source object with trailing and leading blank lines removed. """ @@ -74,18 +89,20 @@ def strip(self): source.lines[:] = self.lines[start:end] return source - def putaround(self, before="", after="", indent=" " * 4): + def putaround( + self, before: str = "", after: str = "", indent: str = " " * 4 + ) -> "Source": """ return a copy of the source object with 'before' and 'after' wrapped around it. """ - before = Source(before) - after = Source(after) + beforesource = Source(before) + aftersource = Source(after) newsource = Source() lines = [(indent + line) for line in self.lines] - newsource.lines = before.lines + lines + after.lines + newsource.lines = beforesource.lines + lines + aftersource.lines return newsource - def indent(self, indent=" " * 4): + def indent(self, indent: str = " " * 4) -> "Source": """ return a copy of the source object with all lines indented by the given indent-string. """ @@ -93,14 +110,14 @@ def indent(self, indent=" " * 4): newsource.lines = [(indent + line) for line in self.lines] return newsource - def getstatement(self, lineno): + def getstatement(self, lineno: int) -> "Source": """ return Source statement which contains the given linenumber (counted from 0). """ start, end = self.getstatementrange(lineno) return self[start:end] - def getstatementrange(self, lineno): + def getstatementrange(self, lineno: int): """ return (start, end) tuple which spans the minimal statement region which containing the given lineno. """ @@ -109,13 +126,13 @@ def getstatementrange(self, lineno): ast, start, end = getstatementrange_ast(lineno, self) return start, end - def deindent(self): + def deindent(self) -> "Source": """return a new source object deindented.""" newsource = Source() newsource.lines[:] = deindent(self.lines) return newsource - def isparseable(self, deindent=True): + def isparseable(self, deindent: bool = True) -> bool: """ return True if source is parseable, heuristically deindenting it by default. """ @@ -135,11 +152,16 @@ def isparseable(self, deindent=True): else: return True - def __str__(self): + def __str__(self) -> str: return "\n".join(self.lines) def compile( - self, filename=None, mode="exec", flag=0, dont_inherit=0, _genframe=None + self, + filename=None, + mode="exec", + flag: int = 0, + dont_inherit: int = 0, + _genframe: Optional[FrameType] = None, ): """ return compiled code object. if filename is None invent an artificial filename which displays @@ -183,7 +205,7 @@ def compile( # -def compile_(source, filename=None, mode="exec", flags=0, dont_inherit=0): +def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0): """ compile the given source to a raw code object, and maintain an internal cache which allows later retrieval of the source code for the code object @@ -233,7 +255,7 @@ def getfslineno(obj): # -def findsource(obj): +def findsource(obj) -> Tuple[Optional[Source], int]: try: sourcelines, lineno = inspect.findsource(obj) except Exception: @@ -243,7 +265,7 @@ def findsource(obj): return source, lineno -def getsource(obj, **kwargs): +def getsource(obj, **kwargs) -> Source: from .code import getrawcode obj = getrawcode(obj) @@ -255,21 +277,21 @@ def getsource(obj, **kwargs): return Source(strsrc, **kwargs) -def deindent(lines): +def deindent(lines: Sequence[str]) -> List[str]: return textwrap.dedent("\n".join(lines)).splitlines() -def get_statement_startend2(lineno, node): +def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]: import ast # flatten all statements and except handlers into one lineno-list # AST's line numbers start indexing at 1 - values = [] + values = [] # type: List[int] for x in ast.walk(node): if isinstance(x, (ast.stmt, ast.ExceptHandler)): values.append(x.lineno - 1) for name in ("finalbody", "orelse"): - val = getattr(x, name, None) + val = getattr(x, name, None) # type: Optional[List[ast.stmt]] if val: # treat the finally/orelse part as its own statement values.append(val[0].lineno - 1 - 1) @@ -283,7 +305,12 @@ def get_statement_startend2(lineno, node): return start, end -def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None): +def getstatementrange_ast( + lineno: int, + source: Source, + assertion: bool = False, + astnode: Optional[ast.AST] = None, +) -> Tuple[ast.AST, int, int]: if astnode is None: content = str(source) # See #4260: diff --git a/src/_pytest/main.py b/src/_pytest/main.py index 5a7858cf0cf..ca577005701 100644 --- a/src/_pytest/main.py +++ b/src/_pytest/main.py @@ -493,7 +493,6 @@ def _perform_collect(self, args, genitems): for arg, exc in self._notfound: line = "(no name {!r} in any of {!r})".format(arg, exc.args[0]) errors.append("not found: {}\n{}".format(arg, line)) - # XXX: test this raise UsageError(*errors) if not genitems: return rep.result diff --git a/src/_pytest/pytester.py b/src/_pytest/pytester.py index de73fa9e245..dad3d72df60 100644 --- a/src/_pytest/pytester.py +++ b/src/_pytest/pytester.py @@ -1,4 +1,5 @@ """(disabled by default) support for testing pytest and pytest plugins.""" +import collections.abc import functools import gc import importlib @@ -10,9 +11,15 @@ import sys import time import traceback -from collections.abc import Sequence from fnmatch import fnmatch from io import StringIO +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple from typing import Union from weakref import WeakKeyDictionary @@ -23,10 +30,16 @@ from _pytest._io.saferepr import saferepr from _pytest.capture import MultiCapture from _pytest.capture import SysCapture +from _pytest.fixtures import FixtureRequest from _pytest.main import ExitCode from _pytest.main import Session from _pytest.monkeypatch import MonkeyPatch from _pytest.pathlib import Path +from _pytest.reports import TestReport + +if False: # TYPE_CHECKING + from typing import Type + IGNORE_PAM = [ # filenames added when obtaining details about the current user "/var/lib/sss/mc/passwd" @@ -144,7 +157,7 @@ def pytest_runtest_protocol(self, item): @pytest.fixture(name="_pytest") -def __pytest(request): +def __pytest(request: FixtureRequest) -> "PytestArg": """Return a helper which offers a gethookrecorder(hook) method which returns a HookRecorder instance which helps to make assertions about called hooks. @@ -154,10 +167,10 @@ def __pytest(request): class PytestArg: - def __init__(self, request): + def __init__(self, request: FixtureRequest) -> None: self.request = request - def gethookrecorder(self, hook): + def gethookrecorder(self, hook) -> "HookRecorder": hookrecorder = HookRecorder(hook._pm) self.request.addfinalizer(hookrecorder.finish_recording) return hookrecorder @@ -178,6 +191,11 @@ def __repr__(self): del d["_name"] return "".format(self._name, d) + if False: # TYPE_CHECKING + # The class has undetermined attributes, this tells mypy about it. + def __getattr__(self, key): + raise NotImplementedError() + class HookRecorder: """Record all hooks called in a plugin manager. @@ -187,27 +205,27 @@ class HookRecorder: """ - def __init__(self, pluginmanager): + def __init__(self, pluginmanager) -> None: self._pluginmanager = pluginmanager - self.calls = [] + self.calls = [] # type: List[ParsedCall] - def before(hook_name, hook_impls, kwargs): + def before(hook_name: str, hook_impls, kwargs) -> None: self.calls.append(ParsedCall(hook_name, kwargs)) - def after(outcome, hook_name, hook_impls, kwargs): + def after(outcome, hook_name: str, hook_impls, kwargs) -> None: pass self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after) - def finish_recording(self): + def finish_recording(self) -> None: self._undo_wrapping() - def getcalls(self, names): + def getcalls(self, names: Union[str, Iterable[str]]) -> List[ParsedCall]: if isinstance(names, str): names = names.split() return [call for call in self.calls if call._name in names] - def assert_contains(self, entries): + def assert_contains(self, entries) -> None: __tracebackhide__ = True i = 0 entries = list(entries) @@ -228,7 +246,7 @@ def assert_contains(self, entries): else: pytest.fail("could not find {!r} check {!r}".format(name, check)) - def popcall(self, name): + def popcall(self, name: str) -> ParsedCall: __tracebackhide__ = True for i, call in enumerate(self.calls): if call._name == name: @@ -238,20 +256,27 @@ def popcall(self, name): lines.extend([" %s" % x for x in self.calls]) pytest.fail("\n".join(lines)) - def getcall(self, name): + def getcall(self, name: str) -> ParsedCall: values = self.getcalls(name) assert len(values) == 1, (name, values) return values[0] # functionality for test reports - def getreports(self, names="pytest_runtest_logreport pytest_collectreport"): + def getreports( + self, + names: Union[ + str, Iterable[str] + ] = "pytest_runtest_logreport pytest_collectreport", + ) -> List[TestReport]: return [x.report for x in self.getcalls(names)] def matchreport( self, - inamepart="", - names="pytest_runtest_logreport pytest_collectreport", + inamepart: str = "", + names: Union[ + str, Iterable[str] + ] = "pytest_runtest_logreport pytest_collectreport", when=None, ): """return a testreport whose dotted import path matches""" @@ -277,13 +302,20 @@ def matchreport( ) return values[0] - def getfailures(self, names="pytest_runtest_logreport pytest_collectreport"): + def getfailures( + self, + names: Union[ + str, Iterable[str] + ] = "pytest_runtest_logreport pytest_collectreport", + ) -> List[TestReport]: return [rep for rep in self.getreports(names) if rep.failed] - def getfailedcollections(self): + def getfailedcollections(self) -> List[TestReport]: return self.getfailures("pytest_collectreport") - def listoutcomes(self): + def listoutcomes( + self + ) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]: passed = [] skipped = [] failed = [] @@ -298,31 +330,31 @@ def listoutcomes(self): failed.append(rep) return passed, skipped, failed - def countoutcomes(self): + def countoutcomes(self) -> List[int]: return [len(x) for x in self.listoutcomes()] - def assertoutcome(self, passed=0, skipped=0, failed=0): + def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None: realpassed, realskipped, realfailed = self.listoutcomes() assert passed == len(realpassed) assert skipped == len(realskipped) assert failed == len(realfailed) - def clear(self): + def clear(self) -> None: self.calls[:] = [] @pytest.fixture -def linecomp(request): +def linecomp(request: FixtureRequest) -> "LineComp": return LineComp() @pytest.fixture(name="LineMatcher") -def LineMatcher_fixture(request): +def LineMatcher_fixture(request: FixtureRequest) -> "Type[LineMatcher]": return LineMatcher @pytest.fixture -def testdir(request, tmpdir_factory): +def testdir(request: FixtureRequest, tmpdir_factory) -> "Testdir": return Testdir(request, tmpdir_factory) @@ -365,7 +397,13 @@ class RunResult: :ivar duration: duration in seconds """ - def __init__(self, ret: Union[int, ExitCode], outlines, errlines, duration) -> None: + def __init__( + self, + ret: Union[int, ExitCode], + outlines: Sequence[str], + errlines: Sequence[str], + duration: float, + ) -> None: try: self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode] except ValueError: @@ -376,13 +414,13 @@ def __init__(self, ret: Union[int, ExitCode], outlines, errlines, duration) -> N self.stderr = LineMatcher(errlines) self.duration = duration - def __repr__(self): + def __repr__(self) -> str: return ( "" % (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration) ) - def parseoutcomes(self): + def parseoutcomes(self) -> Dict[str, int]: """Return a dictionary of outcomestring->num from parsing the terminal output that the test process produced. @@ -395,8 +433,14 @@ def parseoutcomes(self): raise ValueError("Pytest terminal summary report not found") def assert_outcomes( - self, passed=0, skipped=0, failed=0, error=0, xpassed=0, xfailed=0 - ): + self, + passed: int = 0, + skipped: int = 0, + failed: int = 0, + error: int = 0, + xpassed: int = 0, + xfailed: int = 0, + ) -> None: """Assert that the specified outcomes appear with the respective numbers (0 means it didn't occur) in the text output from a test run. @@ -422,19 +466,19 @@ def assert_outcomes( class CwdSnapshot: - def __init__(self): + def __init__(self) -> None: self.__saved = os.getcwd() - def restore(self): + def restore(self) -> None: os.chdir(self.__saved) class SysModulesSnapshot: - def __init__(self, preserve=None): + def __init__(self, preserve: Optional[Callable[[str], bool]] = None): self.__preserve = preserve self.__saved = dict(sys.modules) - def restore(self): + def restore(self) -> None: if self.__preserve: self.__saved.update( (k, m) for k, m in sys.modules.items() if self.__preserve(k) @@ -444,10 +488,10 @@ def restore(self): class SysPathsSnapshot: - def __init__(self): + def __init__(self) -> None: self.__saved = list(sys.path), list(sys.meta_path) - def restore(self): + def restore(self) -> None: sys.path[:], sys.meta_path[:] = self.__saved @@ -1422,7 +1466,7 @@ def _match_lines(self, lines2, match_func, match_nickname): :param str match_nickname: the nickname for the match function that will be logged to stdout when a match occurs """ - assert isinstance(lines2, Sequence) + assert isinstance(lines2, collections.abc.Sequence) lines2 = self._getlines(lines2) lines1 = self.lines[:] nextline = None diff --git a/src/_pytest/warning_types.py b/src/_pytest/warning_types.py index 80353ccbc8c..22cb17dbae6 100644 --- a/src/_pytest/warning_types.py +++ b/src/_pytest/warning_types.py @@ -1,6 +1,14 @@ +from typing import Any +from typing import Generic +from typing import TypeVar + import attr +if False: # TYPE_CHECKING + from typing import Type # noqa: F401 (used in type string) + + class PytestWarning(UserWarning): """ Bases: :class:`UserWarning`. @@ -72,7 +80,7 @@ class PytestExperimentalApiWarning(PytestWarning, FutureWarning): __module__ = "pytest" @classmethod - def simple(cls, apiname): + def simple(cls, apiname: str) -> "PytestExperimentalApiWarning": return cls( "{apiname} is an experimental api that may change over time".format( apiname=apiname @@ -103,17 +111,20 @@ class PytestUnknownMarkWarning(PytestWarning): __module__ = "pytest" +_W = TypeVar("_W", bound=PytestWarning) + + @attr.s -class UnformattedWarning: +class UnformattedWarning(Generic[_W]): """Used to hold warnings that need to format their message at runtime, as opposed to a direct message. Using this class avoids to keep all the warning types and messages in this module, avoiding misuse. """ - category = attr.ib() - template = attr.ib() + category = attr.ib(type="Type[_W]") + template = attr.ib(type=str) - def format(self, **kwargs): + def format(self, **kwargs: Any) -> _W: """Returns an instance of the warning category, formatted with given kwargs""" return self.category(self.template.format(**kwargs)) diff --git a/testing/acceptance_test.py b/testing/acceptance_test.py index 8fa6d734480..084cc8245a0 100644 --- a/testing/acceptance_test.py +++ b/testing/acceptance_test.py @@ -180,8 +180,14 @@ def test_not_collectable_arguments(self, testdir): p1 = testdir.makepyfile("") p2 = testdir.makefile(".pyc", "123") result = testdir.runpytest(p1, p2) - assert result.ret - result.stderr.fnmatch_lines(["*ERROR: not found:*{}".format(p2.basename)]) + assert result.ret == ExitCode.USAGE_ERROR + result.stderr.fnmatch_lines( + [ + "ERROR: not found: {}".format(p2), + "(no name {!r} in any of [[][]])".format(str(p2)), + "", + ] + ) @pytest.mark.filterwarnings("default") def test_better_reporting_on_conftest_load_failure(self, testdir, request): diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index 70a1fea5b1e..71b09a63090 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -1229,13 +1229,15 @@ def g(): @pytest.mark.parametrize( "reason, description", [ - ( + pytest.param( "cause", "The above exception was the direct cause of the following exception:", + id="cause", ), - ( + pytest.param( "context", "During handling of the above exception, another exception occurred:", + id="context", ), ], ) diff --git a/testing/code/test_source.py b/testing/code/test_source.py index 15e0bf24ade..5e7e1abf5a9 100644 --- a/testing/code/test_source.py +++ b/testing/code/test_source.py @@ -127,14 +127,15 @@ def test_isparseable(): class TestAccesses: - source = Source( - """\ - def f(x): - pass - def g(x): - pass - """ - ) + def setup_class(self): + self.source = Source( + """\ + def f(x): + pass + def g(x): + pass + """ + ) def test_getrange(self): x = self.source[0:2] @@ -155,14 +156,15 @@ def test_iter(self): class TestSourceParsingAndCompiling: - source = Source( - """\ - def f(x): - assert (x == - 3 + - 4) - """ - ).strip() + def setup_class(self): + self.source = Source( + """\ + def f(x): + assert (x == + 3 + + 4) + """ + ).strip() def test_compile(self): co = _pytest._code.compile("x=3") @@ -619,7 +621,8 @@ def test_multiline(): class TestTry: - source = """\ + def setup_class(self): + self.source = """\ try: raise ValueError except Something: @@ -646,7 +649,8 @@ def test_else(self): class TestTryFinally: - source = """\ + def setup_class(self): + self.source = """\ try: raise ValueError finally: @@ -663,7 +667,8 @@ def test_finally(self): class TestIf: - source = """\ + def setup_class(self): + self.source = """\ if 1: y = 3 elif False: diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 93e7acfefbc..47d1ddafd9c 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -358,7 +358,7 @@ def test_list(self): @pytest.mark.parametrize( ["left", "right", "expected"], [ - ( + pytest.param( [0, 1], [0, 2], """ @@ -368,8 +368,9 @@ def test_list(self): + [0, 2] ? ^ """, + id="lists", ), - ( + pytest.param( {0: 1}, {0: 2}, """ @@ -379,8 +380,9 @@ def test_list(self): + {0: 2} ? ^ """, + id="dicts", ), - ( + pytest.param( {0, 1}, {0, 2}, """ @@ -390,6 +392,7 @@ def test_list(self): + {0, 2} ? ^ """, + id="sets", ), ], )