diff --git a/README.md b/README.md index 95009f5..3e4e3cc 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,14 @@ def foo() -> tuple[dict[str, int], str | None]: Unused import names will be removed, and if `from __future__ import annotations` is not found in the script, it will be automatically added if the new syntax is being used. +## Use as a command line tool + +```bash +python3 -m pip install -U fix-future-annotations + +fix-future-annotations my_script.py +``` + ## Use as pre-commit hook Add the following to your `.pre-commit-config.yaml`: @@ -149,12 +157,21 @@ repos: - id: fix-future-annotations ``` -## Use as command line tool +## Configurations -```bash -python3 -m pip install -U fix-future-annotations +`fix-future-annotations` can be configured via `pyproject.toml`. Here is an example: -fix-future-annotations my_script.py +```toml +[tool.fix_future_annotations] +exclude_files = [ # regex patterns to exclude files + 'tests/.*', + 'docs/.*', +] + +exclude_lines = [ # regex patterns to exclude lines + '# ffa: ignore', # if a line ends with this comment, the whole *block* will be excluded + 'class .+\(BaseModel\):' # classes that inherit from `BaseModel` will be excluded +] ``` ## License diff --git a/fix_future_annotations/__init__.py b/fix_future_annotations/__init__.py index e69de29..02866c3 100644 --- a/fix_future_annotations/__init__.py +++ b/fix_future_annotations/__init__.py @@ -0,0 +1,4 @@ +from fix_future_annotations._main import fix_file + + +__all__ = ["fix_file"] diff --git a/fix_future_annotations/_config.py b/fix_future_annotations/_config.py new file mode 100644 index 0000000..2deaa82 --- /dev/null +++ b/fix_future_annotations/_config.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +import re +import sys + +if sys.version_info >= (3, 11): + import tomllib +else: + import tomli as tomllib + + +@dataclass +class Config: + """Configuration for fix_future_annotations.""" + + # The line patterns(regex) to exclude from the fix. + exclude_lines: list[str] = field(default_factory=list) + # The file patterns(regex) to exclude from the fix. + exclude_files: list[str] = field(default_factory=list) + + @classmethod + def from_file(cls, path: str | Path = "pyproject.toml") -> Config: + """Load the configuration from a file.""" + try: + with open(path, "rb") as f: + data = tomllib.load(f) + except OSError: + return cls() + else: + return cls(**data.get("tool", {}).get("fix_future_annotations", {})) + + def is_file_excluded(self, file_path: str) -> bool: + return any(re.search(pattern, file_path) for pattern in self.exclude_files) + + def is_line_excluded(self, line: str) -> bool: + return any(re.search(pattern, line) for pattern in self.exclude_lines) diff --git a/fix_future_annotations/_main.py b/fix_future_annotations/_main.py index 71fa903..7706b3b 100644 --- a/fix_future_annotations/_main.py +++ b/fix_future_annotations/_main.py @@ -10,6 +10,7 @@ from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src +from fix_future_annotations._config import Config from fix_future_annotations._visitor import AnnotationVisitor @@ -17,17 +18,19 @@ def _escaped(line: str) -> bool: return (len(line) - len(line.rstrip("\\"))) % 2 == 1 -def _iter_files(*paths: str) -> Iterator[str]: +def _iter_files(*paths: str, config: Config) -> Iterator[str]: def files_under_dir(path: str) -> Iterator[str]: for root, _, files in os.walk(path): for filename in files: if filename.endswith(".py"): - yield os.path.join(root, filename) + fn = os.path.join(root, filename).replace("\\", "/") + if not config.is_file_excluded(fn): + yield fn for path in paths: if os.path.isdir(path): yield from files_under_dir(path) - elif path.endswith(".py"): + elif path.endswith(".py") and not config.is_file_excluded(path): yield path @@ -82,14 +85,20 @@ def _add_future_annotations(content: str) -> str: def fix_file( - file_path: str | Path, write: bool = False, show_diff: bool = False + file_path: str | Path, + *, + write: bool = False, + show_diff: bool = False, + config: Config | None = None, ) -> bool: """Fix the file at file_path to use PEP 585, 604 and 563 syntax.""" + if config is None: + config = Config.from_file() file_path = Path(file_path) file_content = file_path.read_text("utf-8") tokens = src_to_tokens(file_content) tree = ast.parse(file_content) - visitor = AnnotationVisitor() + visitor = AnnotationVisitor(file_content.splitlines(), config=config) token_funcs = visitor.get_token_functions(tree) for i, token in reversed_enumerate(tokens): if not token.src: @@ -137,9 +146,12 @@ def main(argv: list[str] | None = None) -> None: args = parser.parse_args(argv) diff_count = 0 checked = 0 - for filename in _iter_files(*args.path): + config = Config.from_file() + for filename in _iter_files(*args.path, config=config): checked += 1 - result = fix_file(filename, args.write, show_diff=args.verbose) + result = fix_file( + filename, write=args.write, show_diff=args.verbose, config=config + ) diff_count += int(result) if diff_count: if args.write: diff --git a/fix_future_annotations/_visitor.py b/fix_future_annotations/_visitor.py index 21091f6..f5c914d 100644 --- a/fix_future_annotations/_visitor.py +++ b/fix_future_annotations/_visitor.py @@ -8,6 +8,7 @@ from tokenize_rt import NON_CODING_TOKENS, Offset, Token +from fix_future_annotations._config import Config from fix_future_annotations._utils import ( ast_to_offset, find_closing_bracket, @@ -134,11 +135,17 @@ def _fix_union(i: int, tokens: list[Token], *, arg_count: int) -> None: class State(NamedTuple): in_annotation: bool in_literal: bool + omit: bool + + def update_annotation(self) -> bool: + return self.in_annotation and not self.omit class AnnotationVisitor(ast.NodeVisitor): - def __init__(self) -> None: + def __init__(self, lines: list[str], *, config: Config) -> None: super().__init__() + self.lines = lines + self.config = config self.token_funcs: dict[Offset, list[TokenFunc]] = {} self._typing_import_name: str | None = None @@ -162,8 +169,12 @@ def add_conditional_token_func( (condition, partial(self.add_token_func, offset, func)) ) + def _is_excluded(self, node: ast.AST) -> bool: + line = self.lines[node.lineno - 1] + return self.config.is_line_excluded(line) + def get_token_functions(self, tree: ast.Module) -> dict[Offset, list[TokenFunc]]: - with self.under_state(State(False, False)): + with self.under_state(State(False, False, False)): self.visit(tree) for condition, callback in self._conditional_callbacks: if condition(): @@ -188,6 +199,14 @@ def under_state(self, state: State) -> None: finally: self._state_stack.pop() + def visit(self, node: ast.AST) -> Any: + if isinstance(node, ast.stmt) and self._is_excluded(node): + ctx = self.under_state(self.state._replace(omit=True)) + else: + ctx = contextlib.nullcontext() + with ctx: + return super().visit(node) + def generic_visit(self, node: ast.AST) -> Any: for field in reversed(node._fields): value = getattr(node, field) @@ -244,7 +263,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> Any: def visit_Attribute(self, node: ast.Attribute) -> Any: """Transform typing.List -> list""" if ( - self.state.in_annotation + self.state.update_annotation() and isinstance(node.value, ast.Name) and node.value.id == self._typing_import_name and node.attr in BASIC_COLLECTION_TYPES @@ -258,7 +277,7 @@ def visit_Attribute(self, node: ast.Attribute) -> Any: def visit_Name(self, node: ast.Name) -> Any: if node.id in self._typing_imports_to_remove: name = self._typing_imports_to_remove[node.id] - if not self.state.in_annotation: + if not self.state.update_annotation(): # It is referred to outside of an annotation, so we need to exclude it self._conditional_callbacks.insert( 0, @@ -283,7 +302,7 @@ def visit_BinOp(self, node: ast.BinOp) -> Any: return self.generic_visit(node) def visit_Subscript(self, node: ast.Subscript) -> Any: - if not self.state.in_annotation: + if not self.state.update_annotation(): return self.generic_visit(node) if isinstance(node.value, ast.Attribute): if ( @@ -327,7 +346,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: def visit_Constant(self, node: ast.Constant) -> Any: if ( - self.state.in_annotation + self.state.update_annotation() and not self.state.in_literal and isinstance(node.value, str) ): diff --git a/pdm.lock b/pdm.lock index f8b37e8..60a67c1 100644 --- a/pdm.lock +++ b/pdm.lock @@ -70,8 +70,8 @@ requires_python = ">=3.7" summary = "A lil' TOML parser" [metadata] -lock_version = "4.0" -content_hash = "sha256:13ed22f7e12e7c7c3ee7cdeb18d0a3764932cac684d130d177be14ad0546d577" +lock_version = "4.1" +content_hash = "sha256:b50b8c49b0bcf0f0b3c1d684310b54ec461bdd90ce8804ff7e57f17092d2e57b" [metadata.files] "attrs 22.1.0" = [ diff --git a/pyproject.toml b/pyproject.toml index e93ed49..f365fac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ authors = [ ] dependencies = [ "tokenize-rt>=5.0.0", + "tomli; python_version < '3.11'", ] requires-python = ">=3.8" readme = "README.md" @@ -49,3 +50,9 @@ exclude = ''' | tests/samples )/ ''' + + +[tool.fix_future_annotations] +exclude_lines = [ + "# ffa: ignore" +] diff --git a/tests/samples/exclude_lines.py b/tests/samples/exclude_lines.py new file mode 100644 index 0000000..c42eabb --- /dev/null +++ b/tests/samples/exclude_lines.py @@ -0,0 +1,19 @@ +from typing import List, Optional, Tuple, Union + + +class NoFix: + def __init__(self, names: List[str]) -> None: + self.names = names + + def lengh(self) -> Optional[int]: + if self.names: + return len(self.names) + return None + + +def foo() -> Union[str, int]: # ffa: ignore + return 42 + + +def bar() -> Tuple[str, int]: + return "bar", 42 diff --git a/tests/samples/exclude_lines_fix.py b/tests/samples/exclude_lines_fix.py new file mode 100644 index 0000000..b6be68b --- /dev/null +++ b/tests/samples/exclude_lines_fix.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import List, Optional, Union + + +class NoFix: + def __init__(self, names: List[str]) -> None: + self.names = names + + def lengh(self) -> Optional[int]: + if self.names: + return len(self.names) + return None + + +def foo() -> Union[str, int]: # ffa: ignore + return 42 + + +def bar() -> tuple[str, int]: + return "bar", 42 diff --git a/tests/test_fix_future_annotations.py b/tests/test_fix_future_annotations.py index 5190bde..786565f 100644 --- a/tests/test_fix_future_annotations.py +++ b/tests/test_fix_future_annotations.py @@ -3,6 +3,7 @@ import pytest from fix_future_annotations._main import fix_file +from fix_future_annotations._config import Config SAMPLES = Path(__file__).with_name("samples") @@ -19,9 +20,10 @@ def _load_samples() -> list: @pytest.mark.parametrize("origin, fixed", _load_samples()) def test_fix_samples(origin: Path, fixed: Path, tmp_path: Path) -> None: copied = shutil.copy2(origin, tmp_path) - result = fix_file(copied, True) + config = Config(exclude_lines=["# ffa: ignore", "class NoFix:"]) + result = fix_file(copied, write=True, config=config) assert fixed.read_text() == Path(copied).read_text() - result = fix_file(copied, False) + result = fix_file(copied, write=False, config=config) assert not result