diff --git a/.gitignore b/.gitignore index 1e2c77e6..a26083d1 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ share/ .idea/ .vscode/ codealike.json +.python-version \ No newline at end of file diff --git a/docs/source/_static/debugging_support.png b/docs/source/_static/debugging_support.png new file mode 100644 index 00000000..9c12d6a5 Binary files /dev/null and b/docs/source/_static/debugging_support.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py index d58ae0e0..35b3dcc7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,7 +22,7 @@ author = 'Darren Burns' # The full version, including alpha/beta/rc tags -release = '0.51.2b0' +release = '0.52.0b0' # -- General configuration --------------------------------------------------- diff --git a/docs/source/guide/running_tests.rst b/docs/source/guide/running_tests.rst index ed0c6288..eb306081 100644 --- a/docs/source/guide/running_tests.rst +++ b/docs/source/guide/running_tests.rst @@ -203,4 +203,13 @@ output. .. image:: ../_static/show_diff_symbols.png :align: center :height: 150 - :alt: Ward output with diff symbols enabled \ No newline at end of file + :alt: Ward output with diff symbols enabled + +Debugging your code with ``pdb``/``breakpoint()`` +------------------------------------------------- + +Ward will automatically disable output capturing when you use `pdb.set_trace()` or `breakpoint()`, and re-enable it when you exit the debugger. + +.. image:: ../_static/debugging_support.png + :align: center + :alt: Ward debugging example diff --git a/pyproject.toml b/pyproject.toml index 9eb724f0..c666a926 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ path = ["ward/tests"] [tool.poetry] name = "ward" -version = "0.51.2b0" +version = "0.52.0b0" description = "A modern Python testing framework" exclude = ["ward/tests"] authors = ["Darren Burns "] diff --git a/ward/config.py b/ward/config.py index d1a0cd2a..ba5fdf14 100644 --- a/ward/config.py +++ b/ward/config.py @@ -12,6 +12,14 @@ CONFIG_FILE = "pyproject.toml" +def _breakpoint_supported() -> bool: + try: + breakpoint + except NameError: + return False + return True + + def read_config_toml(project_root: Path, config_file: str) -> Config: path = project_root / config_file if not path.is_file(): diff --git a/ward/debug.py b/ward/debug.py new file mode 100644 index 00000000..56adfe1e --- /dev/null +++ b/ward/debug.py @@ -0,0 +1,59 @@ +import importlib +import inspect +import io +import os +import warnings + +import click +import sys + +from ward.config import _breakpoint_supported +from ward.terminal import console + +original_stdout = sys.stdout + + +def init_breakpointhooks(pdb_module, sys_module) -> None: + # breakpoint() is Python 3.7+ + if _breakpoint_supported(): + sys_module.breakpointhook = _breakpointhook + pdb_module.set_trace = _breakpointhook + + +def _breakpointhook(*args, **kwargs): + hookname = os.getenv("PYTHONBREAKPOINT") + if hookname is None or len(hookname) == 0: + hookname = "pdb.set_trace" + kwargs.setdefault("frame", inspect.currentframe().f_back) + elif hookname == "0": + return None + + modname, dot, funcname = hookname.rpartition(".") + if dot == "": + modname = "builtins" + + try: + module = importlib.import_module(modname) + if hookname == "pdb.set_trace": + set_trace = module.Pdb(stdout=original_stdout, skip=["ward*"]).set_trace + hook = set_trace + else: + hook = getattr(module, funcname) + except: + warnings.warn( + f"Ignoring unimportable $PYTHONBREAKPOINT: {hookname}", RuntimeWarning + ) + return None + + context = click.get_current_context() + capture_enabled = context.params.get("capture_output") + capture_active = isinstance(sys.stdout, io.StringIO) + + if capture_enabled and capture_active: + sys.stdout = original_stdout + console.print(f"Entering {modname} - output capturing disabled.", style="info") + return hook(*args, **kwargs) + return hook(*args, **kwargs) + + +__breakpointhook__ = _breakpointhook diff --git a/ward/run.py b/ward/run.py index 6fd05675..eb0ed082 100644 --- a/ward/run.py +++ b/ward/run.py @@ -1,7 +1,9 @@ +import pdb + import sys from pathlib import Path from timeit import default_timer -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import click import click_completion @@ -19,6 +21,7 @@ filter_fixtures, ) from ward.config import set_defaults_from_config +from ward.debug import init_breakpointhooks from ward.rewrite import rewrite_assertions_in_tests from ward.suite import Suite from ward.fixtures import _DEFINED_FIXTURES @@ -139,6 +142,7 @@ def test( dry_run: bool, ): """Run tests.""" + init_breakpointhooks(pdb, sys) start_run = default_timer() paths = [Path(p) for p in path] mod_infos = get_info_for_modules(paths, exclude) @@ -146,7 +150,6 @@ def test( unfiltered_tests = get_tests_in_modules(modules, capture_output) filtered_tests = list(filter_tests(unfiltered_tests, query=search, tag_expr=tags,)) - # Rewrite assertions in each test tests = rewrite_assertions_in_tests(filtered_tests) time_to_collect = default_timer() - start_run diff --git a/ward/terminal.py b/ward/terminal.py index 08245cc3..6b6b3127 100644 --- a/ward/terminal.py +++ b/ward/terminal.py @@ -87,21 +87,6 @@ def make_indent(depth=1): console = Console(theme=theme, highlighter=NullHighlighter()) -def print_no_break(e: Any): - console.print(e, end="") - - -def multiline_description(s: str, indent: int, width: int) -> str: - wrapped = wrap(s, width) - if len(wrapped) == 1: - return wrapped[0] - rv = wrapped[0] - for line in wrapped[1:]: - indent_str = " " * indent - rv += f"\n{indent_str}{line}" - return rv - - def format_test_id(test_result: TestResult) -> str: """ Format module name, line number, and test case number diff --git a/ward/testing.py b/ward/testing.py index dd7ee94a..8a95eb0c 100644 --- a/ward/testing.py +++ b/ward/testing.py @@ -4,6 +4,7 @@ import inspect import traceback import uuid +from bdb import BdbQuit from collections import defaultdict from contextlib import ExitStack, closing, redirect_stderr, redirect_stdout from dataclasses import dataclass, field @@ -156,6 +157,12 @@ def run(self, cache: FixtureCache, dry_run=False) -> "TestResult": except FixtureError as e: outcome = TestOutcome.FAIL error: Optional[Exception] = e + except BdbQuit: + # We don't want to treat the user quitting the debugger + # as an exception, so we'll ignore BdbQuit. This will + # also prevent a large pdb-internal stack trace flooding + # the terminal. + pass except (Exception, SystemExit) as e: outcome = ( TestOutcome.XFAIL diff --git a/ward/tests/test_debug.py b/ward/tests/test_debug.py new file mode 100644 index 00000000..e374d8d8 --- /dev/null +++ b/ward/tests/test_debug.py @@ -0,0 +1,32 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +from ward import test, debug +from ward.debug import init_breakpointhooks, _breakpointhook + + +@test("init_breakpointhooks always patches pdb.set_trace") +def _(): + mock_pdb = Mock() + init_breakpointhooks(pdb_module=mock_pdb, sys_module=Mock()) + assert mock_pdb.set_trace == _breakpointhook + + +@test("init_breakpointhooks sets sys.breakpointhook when it's supported") +def _(): + old_func = debug._breakpoint_supported + debug._breakpoint_supported = lambda: True + mock_sys = SimpleNamespace() + init_breakpointhooks(pdb_module=Mock(), sys_module=mock_sys) + debug._breakpoint_supported = old_func + assert mock_sys.breakpointhook == _breakpointhook + + +@test("init_breakpointhooks doesnt set breakpointhook when it's unsupported") +def _(): + old_func = debug._breakpoint_supported + debug._breakpoint_supported = lambda: False + mock_sys = SimpleNamespace() + init_breakpointhooks(pdb_module=Mock(), sys_module=mock_sys) + debug._breakpoint_supported = old_func + assert not hasattr(mock_sys, "breakpointhook") diff --git a/ward/tests/test_util.py b/ward/tests/test_util.py index 44049e9c..88415456 100644 --- a/ward/tests/test_util.py +++ b/ward/tests/test_util.py @@ -1,9 +1,8 @@ import os -import sys from pathlib import Path from ward import test, using, fixture -from ward.testing import each, xfail, skip +from ward.testing import each from ward.tests.utilities import make_project from ward.util import ( truncate,