Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support narrowing literals and enums using the in operator in combination with list, set, and tuple expressions. #17044

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion docs/source/literal_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ perform an exhaustiveness check, you need to update your code to use an

.. code-block:: python

from typing import Literal, NoReturn
from typing import Literal
from typing_extensions import assert_never

PossibleValues = Literal['one', 'two']
Expand Down Expand Up @@ -368,6 +368,19 @@ without a value:
elif x == 'two':
return False

For the sake of brevity, you can use the ``in`` operator in combination with
list, set, or tuple expressions (lists, sets, or tuples created "on the fly"):

.. code-block:: python

PossibleValues = Literal['one', 'two', 'three']

def validate(x: PossibleValues) -> bool:
if x in ['one']:
return True
elif x in ('two', 'three'):
return False

Exhaustiveness checking is also supported for match statements (Python 3.10 and later):

.. code-block:: python
Expand Down
61 changes: 61 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import itertools
from collections import defaultdict
from contextlib import ExitStack, contextmanager
from copy import copy
from typing import (
AbstractSet,
Callable,
Expand Down Expand Up @@ -116,6 +117,7 @@
RaiseStmt,
RefExpr,
ReturnStmt,
SetExpr,
StarExpr,
Statement,
StrExpr,
Expand Down Expand Up @@ -4684,12 +4686,71 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
if self.in_checked_function():
self.fail(message_registry.RETURN_VALUE_EXPECTED, s)

def _transform_sequence_expressions_for_narrowing_with_in(self, e: Expression) -> Expression:
"""
Transform an expression like

(x is None) and (x in (1, 2)) and (x not in [3, 4])

into

(x is None) and (x == 1 or x == 2) and (x != 3 and x != 4)

This transformation is supposed to enable narrowing literals and enums using the in
(and the not in) operator in combination with tuple expressions without the need to
implement additional narrowing logic.
"""
if isinstance(e, OpExpr):
e.left = self._transform_sequence_expressions_for_narrowing_with_in(e.left)
e.right = self._transform_sequence_expressions_for_narrowing_with_in(e.right)
return e

if not (
isinstance(e, ComparisonExpr)
and isinstance(left := e.operands[0], NameExpr)
and ((op_in := e.operators[0]) in ("in", "not in"))
and isinstance(litu := e.operands[1], (ListExpr, SetExpr, TupleExpr))
):
return e

op_eq, op_con = (["=="], "or") if (op_in == "in") else (["!="], "and")
line = e.line
left_new = left
comparisons = []
for right in reversed(litu.items):
if isinstance(right, StarExpr):
return e
comparison = ComparisonExpr(op_eq, [left_new, right])
comparison.line = line
comparisons.append(comparison)
left_new = copy(left)
if (nmb := len(comparisons)) == 0:
if op_in == "in":
e = NameExpr("False")
e.fullname = "builtins.False"
e.line = line
return e
e = NameExpr("True")
e.fullname = "builtins.True"
e.line = line
return e
if nmb == 1:
return comparisons[0]
e = OpExpr(op_con, comparisons[1], comparisons[0])
for comparison in comparisons[2:]:
e = OpExpr(op_con, comparison, e)
e.line = line
return e

def visit_if_stmt(self, s: IfStmt) -> None:
"""Type check an if statement."""
# This frame records the knowledge from previous if/elif clauses not being taken.
# Fall-through to the original frame is handled explicitly in each block.
with self.binder.frame_context(can_skip=False, conditional_frame=True, fall_through=0):
for e, b in zip(s.expr, s.body):

e = self._transform_sequence_expressions_for_narrowing_with_in(e)

t = get_proper_type(self.expr_checker.accept(e))

if isinstance(t, DeletedType):
Expand Down
4 changes: 3 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1876,7 +1876,9 @@ class NameExpr(RefExpr):

__match_args__ = ("name", "node")

def __init__(self, name: str) -> None:
def __init__(self, name: str = "?") -> None:
# The default name "?" aims to make NameExpr mypyc copyable.
# Always pass a proper name when manually calling NameExpr.__init__.
super().__init__()
self.name = name # Name referred to
# Is this a l.h.s. of a special form assignment like typed dict or type variable?
Expand Down
9 changes: 9 additions & 0 deletions mypyc/test-data/run-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,12 @@ TUPLE: Final[Tuple[str, ...]] = ('x', 'y')
def test_final_boxed_tuple() -> None:
t = TUPLE
assert t == ('x', 'y')

[case testTupleDoNotCrashOnTransformedInComparisons]
def f() -> None:
for n in ["x"]:
if n in ("x", "z") or n.startswith("y"):
print(n)
f()
[out]
x
2 changes: 1 addition & 1 deletion test-data/unit/check-isinstance.test
Original file line number Diff line number Diff line change
Expand Up @@ -1993,7 +1993,7 @@ class C(A): pass

y: Optional[B]
if y in (B(), C()):
reveal_type(y) # N: Revealed type is "__main__.B"
reveal_type(y) # N: Revealed type is "Union[__main__.B, None]"
else:
reveal_type(y) # N: Revealed type is "Union[__main__.B, None]"
[builtins fixtures/tuple.pyi]
Expand Down
101 changes: 101 additions & 0 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -2333,3 +2333,104 @@ def f(x: C) -> None:

f(C(5))
[builtins fixtures/primitives.pyi]

[case testNarrowLiteralsInListOrSetOrTupleExpression]
# flags: --warn-unreachable

from typing import Optional
from typing_extensions import Literal

x: int

def f(v: Optional[Literal[1, 2, 3, 4]]) -> None:
if v in (0, 1, 2):
reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2]]"
elif v in [1]:
reveal_type(v) # E: Statement is unreachable
elif v is None or v in {3, x}:
reveal_type(v) # N: Revealed type is "Union[Literal[3], Literal[4], None]"
elif v in ():
reveal_type(v) # E: Statement is unreachable
else:
reveal_type(v) # N: Revealed type is "Literal[4]"
reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], None]"
[builtins fixtures/primitives.pyi]

[case testNarrowLiteralsNotInListOrSetOrTupleExpression]
# flags: --warn-unreachable

from typing import Optional
from typing_extensions import Literal

x: int

def f(v: Optional[Literal[1, 2, 3, 4, 5]]) -> None:
if v not in {0, 1, 2, 3}:
reveal_type(v) # N: Revealed type is "Union[Literal[4], Literal[5], None]"
elif v not in [1, 2, 3, 4]: # E: Right operand of "and" is never evaluated
reveal_type(v) # E: Statement is unreachable
elif v is not None and v not in (3,):
reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2]]"
elif v not in (x, 3):
reveal_type(v) # E: Statement is unreachable
else:
reveal_type(v) # N: Revealed type is "Literal[3]"
reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], Literal[5], None]"
[builtins fixtures/primitives.pyi]

[case testNarrowEnumsInListOrSetOrTupleExpression]
from enum import Enum
from typing import Final

class E(Enum):
A = 1
B = 2
C = 3
D = 4

A: Final = E.A
C: Final = E.C

def f(v: E) -> None:
reveal_type(v) # N: Revealed type is "__main__.E"
if v in (A, E.B):
reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]"
elif v in [E.A]:
reveal_type(v)
elif v in {C}:
reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]"
elif v in ():
reveal_type(v)
else:
reveal_type(v) # N: Revealed type is "Literal[__main__.E.D]"
reveal_type(v) # N: Revealed type is "__main__.E"
[builtins fixtures/primitives.pyi]

[case testNarrowEnumsNotInListOrSetOrTupleExpression]
from enum import Enum
from typing import Final

class E(Enum):
A = 1
B = 2
C = 3
D = 4
E = 5

A: Final = E.A
C: Final = E.C

def f(v: E) -> None:
reveal_type(v) # N: Revealed type is "__main__.E"
if v not in (A, E.B, E.C):
reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.D], Literal[__main__.E.E]]"
elif v not in [E.A, E.B, E.C, E.C]:
reveal_type(v)
elif v not in {C}:
reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]"
elif v not in []:
reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]"
else:
reveal_type(v)
reveal_type(v) # N: Revealed type is "__main__.E"
[builtins fixtures/primitives.pyi]
1 change: 1 addition & 0 deletions test-data/unit/fixtures/tuple.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ _Tco = TypeVar('_Tco', covariant=True)

class object:
def __init__(self) -> None: pass
def __eq__(self, other: object) -> bool: pass

class type:
def __init__(self, *a: object) -> None: pass
Expand Down
Loading