diff --git a/fix_future_annotations/_main.py b/fix_future_annotations/_main.py index 65816e8..71fa903 100644 --- a/fix_future_annotations/_main.py +++ b/fix_future_annotations/_main.py @@ -4,7 +4,9 @@ import ast import difflib import sys +import os from pathlib import Path +from typing import Iterator from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src @@ -15,6 +17,20 @@ def _escaped(line: str) -> bool: return (len(line) - len(line.rstrip("\\"))) % 2 == 1 +def _iter_files(*paths: str) -> Iterator[str]: + def files_under_dir(path: str) -> Iterator[str]: + for root, _, files in os.walk(path): + for filename in files: + if filename.endswith(".py"): + yield os.path.join(root, filename) + + for path in paths: + if os.path.isdir(path): + yield from files_under_dir(path) + elif path.endswith(".py"): + yield path + + def _add_future_annotations(content: str) -> str: """Add from __future__ annotations after the first docstring and comments""" new_lines = ["from __future__ import annotations\n"] @@ -106,7 +122,7 @@ def fix_file( def main(argv: list[str] | None = None) -> None: parser = argparse.ArgumentParser() - parser.add_argument("filenames", nargs="+", help="File(s) to fix") + parser.add_argument("path", nargs="+", help="File or directory path(s) to fix") parser.add_argument( "--check", "-c", @@ -120,16 +136,18 @@ def main(argv: list[str] | None = None) -> None: ) args = parser.parse_args(argv) diff_count = 0 - for filename in args.filenames: - if filename.endswith(".py"): - result = fix_file(filename, args.write, show_diff=args.verbose) - diff_count += int(result) + checked = 0 + for filename in _iter_files(*args.path): + checked += 1 + result = fix_file(filename, args.write, show_diff=args.verbose) + diff_count += int(result) if diff_count: if args.write: message = f"All complete, {diff_count} files were fixed" else: message = f"All complete, {diff_count} files need to be fixed" - print(message) + if checked > 1: # multiple mode, print a summary + print(message) sys.exit(1) - else: + elif checked > 1: print("All complete, no file is changed")