Skip to content

Commit

Permalink
feat: support PEP 563 string replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
frostming committed Dec 2, 2022
1 parent 3b00a68 commit d4a74d2
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 17 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
repos:
- repo: https://github.com/frostming/fix-future-annotations
rev: 0.1.0
rev: 0.2.0
hooks:
- id: fix-future-annotations
exclude: tests/samples/
27 changes: 25 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,26 @@ str | None
</td></tr></tbody>
</table>

### [PEP 563] – Postponed Evaluation of Annotations

<table>
<thead>
<tr><th>Old</th><th>New</th></tr>
</thead>
<tbody>
<tr><td>

```python
def create() -> "Foo": pass
```
</td><td>

```python
def create() -> Foo: pass
```
</td></tr></tbody>
</table>

### Import aliases handling

<table>
Expand Down Expand Up @@ -91,7 +111,8 @@ def foo() -> tuple[str, int | None]:
```python
from typing import Union, Dict, Optional, Tuple

MyType = Union[str, int] # non-annotation usage will be preserved
# non-annotation usage will be preserved
MyType = Union[str, int]


def foo() -> Tuple[Dict[str, int], Optional[str]]:
Expand All @@ -104,7 +125,8 @@ from __future__ import annotations

from typing import Union

MyType = Union[str, int] # non-annotation usage will be preserved
# non-annotation usage will be preserved
MyType = Union[str, int]


def foo() -> tuple[dict[str, int], str | None]:
Expand Down Expand Up @@ -139,5 +161,6 @@ fix-future-annotations my_script.py

This work is distributed under [MIT](https://github.com/frostming/fix-future-annotations/blob/main/README.md) license.

[PEP 563]: https://peps.python.org/pep-0563/
[PEP 585]: https://peps.python.org/pep-0585/
[PEP 604]: https://peps.python.org/pep-0604/
5 changes: 5 additions & 0 deletions fix_future_annotations/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ def replace_name(i: int, tokens: list[Token], *, name: str, new: str) -> None:
tokens[i : j + 1] = [new_token]


def replace_string(i: int, tokens: list[Token], *, new: str) -> None:
new_token = tokens[i]._replace(name="CODE", src=new)
tokens[i] = new_token


def remove_name_from_import(i: int, tokens: list[Token], *, name: str) -> None:
while tokens[i].src != name:
i += 1
Expand Down
67 changes: 53 additions & 14 deletions fix_future_annotations/_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import contextlib
import sys
from functools import partial
from typing import Any, Callable, List
from typing import Any, Callable, List, NamedTuple

from tokenize_rt import NON_CODING_TOKENS, Offset, Token

Expand All @@ -15,6 +15,7 @@
remove_name_from_import,
remove_statement,
replace_name,
replace_string,
)

BASIC_COLLECTION_TYPES = frozenset(
Expand Down Expand Up @@ -130,16 +131,23 @@ def _fix_union(i: int, tokens: list[Token], *, arg_count: int) -> None:
del tokens[i:j]


class State(NamedTuple):
in_annotation: bool
in_literal: bool


class AnnotationVisitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self.token_funcs: dict[Offset, list[TokenFunc]] = {}

self._typing_import_name: str | None = None
self._typing_extensions_import_name: str | None = None
self._has_future_annotations = False
self._using_new_annotations = False
self._in_annotation_stack: list[bool] = [False]
self._state_stack: list[State] = []
self._typing_imports_to_remove: dict[str, str] = {}
self._literal_import_name: str | None = None
self._conditional_callbacks: list[
tuple[Callable[[], bool], Callable[[], None]]
] = []
Expand All @@ -155,15 +163,16 @@ def add_conditional_token_func(
)

def get_token_functions(self, tree: ast.Module) -> dict[Offset, list[TokenFunc]]:
self.visit(tree)
with self.under_state(State(False, False)):
self.visit(tree)
for condition, callback in self._conditional_callbacks:
if condition():
callback()
return self.token_funcs

@property
def in_annotation(self) -> bool:
return self._in_annotation_stack[-1]
def state(self) -> State:
return self._state_stack[-1]

@property
def need_future_annotations(self) -> bool:
Expand All @@ -172,18 +181,18 @@ def need_future_annotations(self) -> bool:
)

@contextlib.contextmanager
def visit_annotation(self) -> None:
self._in_annotation_stack.append(True)
def under_state(self, state: State) -> None:
self._state_stack.append(state)
try:
yield
finally:
self._in_annotation_stack.pop()
self._state_stack.pop()

def generic_visit(self, node: ast.AST) -> Any:
for field in reversed(node._fields):
value = getattr(node, field)
if field in {"annotation", "returns"}:
ctx = self.visit_annotation()
ctx = self.under_state(self.state._replace(in_annotation=True))
else:
ctx = contextlib.nullcontext()
with ctx:
Expand All @@ -197,7 +206,9 @@ def generic_visit(self, node: ast.AST) -> Any:
def visit_Import(self, node: ast.Import) -> Any:
for alias in node.names:
if alias.name == "typing":
self._typing_import_name = alias.asname or "typing"
self._typing_import_name = alias.asname or alias.name
elif alias.name == "typing_extensions":
self._typing_extensions_import_name = alias.asname or alias.name
return self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
Expand All @@ -208,6 +219,8 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
names: set[str] = {(alias.asname or alias.name) for alias in node.names}
for alias in reversed(node.names):
key = alias.asname or alias.name
if alias.name == "Literal":
self._literal_import_name = key
if alias.name in IMPORTS_TO_REMOVE:
self._typing_imports_to_remove[key] = alias.name
self.add_conditional_token_func(
Expand All @@ -221,13 +234,17 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
ast_to_offset(node),
remove_statement,
)
elif node.module == "typing_extensions":
alias = next((a for a in node.names if a.name == "Literal"), None)
if alias is not None:
self._literal_import_name = alias.asname or alias.name
else:
return self.generic_visit(node)

def visit_Attribute(self, node: ast.Attribute) -> Any:
"""Transform typing.List -> list"""
if (
self.in_annotation
self.state.in_annotation
and isinstance(node.value, ast.Name)
and node.value.id == self._typing_import_name
and node.attr in BASIC_COLLECTION_TYPES
Expand All @@ -241,7 +258,7 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
def visit_Name(self, node: ast.Name) -> Any:
if node.id in self._typing_imports_to_remove:
name = self._typing_imports_to_remove[node.id]
if not self.in_annotation:
if not self.state.in_annotation:
# It is referred to outside of an annotation, so we need to exclude it
self._conditional_callbacks.insert(
0,
Expand All @@ -261,12 +278,12 @@ def visit_Name(self, node: ast.Name) -> Any:
return self.generic_visit(node)

def visit_BinOp(self, node: ast.BinOp) -> Any:
if self.in_annotation:
if self.state.in_annotation:
self._using_new_annotations = True
return self.generic_visit(node)

def visit_Subscript(self, node: ast.Subscript) -> Any:
if not self.in_annotation:
if not self.state.in_annotation:
return self.generic_visit(node)
if isinstance(node.value, ast.Attribute):
if (
Expand All @@ -282,6 +299,14 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
ast_to_offset(node),
partial(_fix_union, arg_count=arg_count),
)
elif (
isinstance(node.value.value, ast.Name)
and node.value.value.id
in {self._typing_import_name, self._typing_extensions_import_name}
and node.value.attr == "Literal"
):
with self.under_state(self.state._replace(in_literal=True)):
return self.generic_visit(node)
elif isinstance(node.value, ast.Name):
if node.value.id in self._typing_imports_to_remove:
if self._typing_imports_to_remove[node.value.id] == "Optional":
Expand All @@ -295,4 +320,18 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
)
elif node.value.id in {name.lower() for name in BASIC_COLLECTION_TYPES}:
self._using_new_annotations = True
elif node.value.id == self._literal_import_name:
with self.under_state(self.state._replace(in_literal=True)):
return self.generic_visit(node)
return self.generic_visit(node)

def visit_Constant(self, node: ast.Constant) -> Any:
if (
self.state.in_annotation
and not self.state.in_literal
and isinstance(node.value, str)
):
self.add_token_func(
ast_to_offset(node), partial(replace_string, new=node.value)
)
return self.generic_visit(node)
7 changes: 7 additions & 0 deletions tests/samples/convert_string_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Literal


class Foo:
@classmethod
def create(cls, param: Literal["foo", "bar"]) -> "Foo":
pass
9 changes: 9 additions & 0 deletions tests/samples/convert_string_constants_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from typing import Literal


class Foo:
@classmethod
def create(cls, param: Literal["foo", "bar"]) -> Foo:
pass
7 changes: 7 additions & 0 deletions tests/samples/typing_extensions_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing_extensions import Literal


class Foo:
@classmethod
def create(cls, param: Literal["foo", "bar"]) -> "Foo":
pass
7 changes: 7 additions & 0 deletions tests/samples/typing_extensions_import2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import typing_extensions as te


class Foo:
@classmethod
def create(cls, param: te.Literal["foo", "bar"]) -> "Foo":
pass
9 changes: 9 additions & 0 deletions tests/samples/typing_extensions_import2_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

import typing_extensions as te


class Foo:
@classmethod
def create(cls, param: te.Literal["foo", "bar"]) -> Foo:
pass
9 changes: 9 additions & 0 deletions tests/samples/typing_extensions_import_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from typing_extensions import Literal


class Foo:
@classmethod
def create(cls, param: Literal["foo", "bar"]) -> Foo:
pass

0 comments on commit d4a74d2

Please sign in to comment.