From d4a74d2fd12574fee20647bc4390141b0abfdeb6 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Fri, 2 Dec 2022 18:28:50 +0800 Subject: [PATCH] feat: support PEP 563 string replacement --- .pre-commit-config.yaml | 3 +- README.md | 27 +++++++- fix_future_annotations/_utils.py | 5 ++ fix_future_annotations/_visitor.py | 67 +++++++++++++++---- tests/samples/convert_string_constants.py | 7 ++ tests/samples/convert_string_constants_fix.py | 9 +++ tests/samples/typing_extensions_import.py | 7 ++ tests/samples/typing_extensions_import2.py | 7 ++ .../samples/typing_extensions_import2_fix.py | 9 +++ tests/samples/typing_extensions_import_fix.py | 9 +++ 10 files changed, 133 insertions(+), 17 deletions(-) create mode 100644 tests/samples/convert_string_constants.py create mode 100644 tests/samples/convert_string_constants_fix.py create mode 100644 tests/samples/typing_extensions_import.py create mode 100644 tests/samples/typing_extensions_import2.py create mode 100644 tests/samples/typing_extensions_import2_fix.py create mode 100644 tests/samples/typing_extensions_import_fix.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bb79940..9b80925 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,6 @@ repos: - repo: https://github.com/frostming/fix-future-annotations - rev: 0.1.0 + rev: 0.2.0 hooks: - id: fix-future-annotations + exclude: tests/samples/ diff --git a/README.md b/README.md index b4a4644..dcc37e1 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,26 @@ str | None +### [PEP 563] – Postponed Evaluation of Annotations + + + + + + + +
OldNew
+ +```python +def create() -> "Foo": pass +``` + + +```python +def create() -> Foo: pass +``` +
+ ### Import aliases handling @@ -91,7 +111,8 @@ def foo() -> tuple[str, int | None]: ```python from typing import Union, Dict, Optional, Tuple -MyType = Union[str, int] # non-annotation usage will be preserved +# non-annotation usage will be preserved +MyType = Union[str, int] def foo() -> Tuple[Dict[str, int], Optional[str]]: @@ -104,7 +125,8 @@ from __future__ import annotations from typing import Union -MyType = Union[str, int] # non-annotation usage will be preserved +# non-annotation usage will be preserved +MyType = Union[str, int] def foo() -> tuple[dict[str, int], str | None]: @@ -139,5 +161,6 @@ fix-future-annotations my_script.py This work is distributed under [MIT](https://github.com/frostming/fix-future-annotations/blob/main/README.md) license. +[PEP 563]: https://peps.python.org/pep-0563/ [PEP 585]: https://peps.python.org/pep-0585/ [PEP 604]: https://peps.python.org/pep-0604/ diff --git a/fix_future_annotations/_utils.py b/fix_future_annotations/_utils.py index 0110d8f..c4aa50f 100644 --- a/fix_future_annotations/_utils.py +++ b/fix_future_annotations/_utils.py @@ -17,6 +17,11 @@ def replace_name(i: int, tokens: list[Token], *, name: str, new: str) -> None: tokens[i : j + 1] = [new_token] +def replace_string(i: int, tokens: list[Token], *, new: str) -> None: + new_token = tokens[i]._replace(name="CODE", src=new) + tokens[i] = new_token + + def remove_name_from_import(i: int, tokens: list[Token], *, name: str) -> None: while tokens[i].src != name: i += 1 diff --git a/fix_future_annotations/_visitor.py b/fix_future_annotations/_visitor.py index a4b2fb3..21091f6 100644 --- a/fix_future_annotations/_visitor.py +++ b/fix_future_annotations/_visitor.py @@ -4,7 +4,7 @@ import contextlib import sys from functools import partial -from typing import Any, Callable, List +from typing import Any, Callable, List, NamedTuple from tokenize_rt import NON_CODING_TOKENS, Offset, Token @@ -15,6 +15,7 @@ remove_name_from_import, remove_statement, replace_name, + replace_string, ) BASIC_COLLECTION_TYPES = frozenset( @@ -130,16 +131,23 @@ def _fix_union(i: int, tokens: list[Token], *, arg_count: int) -> None: del tokens[i:j] +class State(NamedTuple): + in_annotation: bool + in_literal: bool + + class AnnotationVisitor(ast.NodeVisitor): def __init__(self) -> None: super().__init__() self.token_funcs: dict[Offset, list[TokenFunc]] = {} self._typing_import_name: str | None = None + self._typing_extensions_import_name: str | None = None self._has_future_annotations = False self._using_new_annotations = False - self._in_annotation_stack: list[bool] = [False] + self._state_stack: list[State] = [] self._typing_imports_to_remove: dict[str, str] = {} + self._literal_import_name: str | None = None self._conditional_callbacks: list[ tuple[Callable[[], bool], Callable[[], None]] ] = [] @@ -155,15 +163,16 @@ def add_conditional_token_func( ) def get_token_functions(self, tree: ast.Module) -> dict[Offset, list[TokenFunc]]: - self.visit(tree) + with self.under_state(State(False, False)): + self.visit(tree) for condition, callback in self._conditional_callbacks: if condition(): callback() return self.token_funcs @property - def in_annotation(self) -> bool: - return self._in_annotation_stack[-1] + def state(self) -> State: + return self._state_stack[-1] @property def need_future_annotations(self) -> bool: @@ -172,18 +181,18 @@ def need_future_annotations(self) -> bool: ) @contextlib.contextmanager - def visit_annotation(self) -> None: - self._in_annotation_stack.append(True) + def under_state(self, state: State) -> None: + self._state_stack.append(state) try: yield finally: - self._in_annotation_stack.pop() + self._state_stack.pop() def generic_visit(self, node: ast.AST) -> Any: for field in reversed(node._fields): value = getattr(node, field) if field in {"annotation", "returns"}: - ctx = self.visit_annotation() + ctx = self.under_state(self.state._replace(in_annotation=True)) else: ctx = contextlib.nullcontext() with ctx: @@ -197,7 +206,9 @@ def generic_visit(self, node: ast.AST) -> Any: def visit_Import(self, node: ast.Import) -> Any: for alias in node.names: if alias.name == "typing": - self._typing_import_name = alias.asname or "typing" + self._typing_import_name = alias.asname or alias.name + elif alias.name == "typing_extensions": + self._typing_extensions_import_name = alias.asname or alias.name return self.generic_visit(node) def visit_ImportFrom(self, node: ast.ImportFrom) -> Any: @@ -208,6 +219,8 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> Any: names: set[str] = {(alias.asname or alias.name) for alias in node.names} for alias in reversed(node.names): key = alias.asname or alias.name + if alias.name == "Literal": + self._literal_import_name = key if alias.name in IMPORTS_TO_REMOVE: self._typing_imports_to_remove[key] = alias.name self.add_conditional_token_func( @@ -221,13 +234,17 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> Any: ast_to_offset(node), remove_statement, ) + elif node.module == "typing_extensions": + alias = next((a for a in node.names if a.name == "Literal"), None) + if alias is not None: + self._literal_import_name = alias.asname or alias.name else: return self.generic_visit(node) def visit_Attribute(self, node: ast.Attribute) -> Any: """Transform typing.List -> list""" if ( - self.in_annotation + self.state.in_annotation and isinstance(node.value, ast.Name) and node.value.id == self._typing_import_name and node.attr in BASIC_COLLECTION_TYPES @@ -241,7 +258,7 @@ def visit_Attribute(self, node: ast.Attribute) -> Any: def visit_Name(self, node: ast.Name) -> Any: if node.id in self._typing_imports_to_remove: name = self._typing_imports_to_remove[node.id] - if not self.in_annotation: + if not self.state.in_annotation: # It is referred to outside of an annotation, so we need to exclude it self._conditional_callbacks.insert( 0, @@ -261,12 +278,12 @@ def visit_Name(self, node: ast.Name) -> Any: return self.generic_visit(node) def visit_BinOp(self, node: ast.BinOp) -> Any: - if self.in_annotation: + if self.state.in_annotation: self._using_new_annotations = True return self.generic_visit(node) def visit_Subscript(self, node: ast.Subscript) -> Any: - if not self.in_annotation: + if not self.state.in_annotation: return self.generic_visit(node) if isinstance(node.value, ast.Attribute): if ( @@ -282,6 +299,14 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: ast_to_offset(node), partial(_fix_union, arg_count=arg_count), ) + elif ( + isinstance(node.value.value, ast.Name) + and node.value.value.id + in {self._typing_import_name, self._typing_extensions_import_name} + and node.value.attr == "Literal" + ): + with self.under_state(self.state._replace(in_literal=True)): + return self.generic_visit(node) elif isinstance(node.value, ast.Name): if node.value.id in self._typing_imports_to_remove: if self._typing_imports_to_remove[node.value.id] == "Optional": @@ -295,4 +320,18 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: ) elif node.value.id in {name.lower() for name in BASIC_COLLECTION_TYPES}: self._using_new_annotations = True + elif node.value.id == self._literal_import_name: + with self.under_state(self.state._replace(in_literal=True)): + return self.generic_visit(node) + return self.generic_visit(node) + + def visit_Constant(self, node: ast.Constant) -> Any: + if ( + self.state.in_annotation + and not self.state.in_literal + and isinstance(node.value, str) + ): + self.add_token_func( + ast_to_offset(node), partial(replace_string, new=node.value) + ) return self.generic_visit(node) diff --git a/tests/samples/convert_string_constants.py b/tests/samples/convert_string_constants.py new file mode 100644 index 0000000..9f4f22c --- /dev/null +++ b/tests/samples/convert_string_constants.py @@ -0,0 +1,7 @@ +from typing import Literal + + +class Foo: + @classmethod + def create(cls, param: Literal["foo", "bar"]) -> "Foo": + pass diff --git a/tests/samples/convert_string_constants_fix.py b/tests/samples/convert_string_constants_fix.py new file mode 100644 index 0000000..2805023 --- /dev/null +++ b/tests/samples/convert_string_constants_fix.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from typing import Literal + + +class Foo: + @classmethod + def create(cls, param: Literal["foo", "bar"]) -> Foo: + pass diff --git a/tests/samples/typing_extensions_import.py b/tests/samples/typing_extensions_import.py new file mode 100644 index 0000000..d9277de --- /dev/null +++ b/tests/samples/typing_extensions_import.py @@ -0,0 +1,7 @@ +from typing_extensions import Literal + + +class Foo: + @classmethod + def create(cls, param: Literal["foo", "bar"]) -> "Foo": + pass diff --git a/tests/samples/typing_extensions_import2.py b/tests/samples/typing_extensions_import2.py new file mode 100644 index 0000000..5d9ec67 --- /dev/null +++ b/tests/samples/typing_extensions_import2.py @@ -0,0 +1,7 @@ +import typing_extensions as te + + +class Foo: + @classmethod + def create(cls, param: te.Literal["foo", "bar"]) -> "Foo": + pass diff --git a/tests/samples/typing_extensions_import2_fix.py b/tests/samples/typing_extensions_import2_fix.py new file mode 100644 index 0000000..490243b --- /dev/null +++ b/tests/samples/typing_extensions_import2_fix.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +import typing_extensions as te + + +class Foo: + @classmethod + def create(cls, param: te.Literal["foo", "bar"]) -> Foo: + pass diff --git a/tests/samples/typing_extensions_import_fix.py b/tests/samples/typing_extensions_import_fix.py new file mode 100644 index 0000000..fb677ff --- /dev/null +++ b/tests/samples/typing_extensions_import_fix.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from typing_extensions import Literal + + +class Foo: + @classmethod + def create(cls, param: Literal["foo", "bar"]) -> Foo: + pass