Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support .value inference for Union of enums #15939

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 76 additions & 60 deletions mypy/plugins/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@
from mypy.semanal_enum import ENUM_BASES
from mypy.subtypes import is_equivalent
from mypy.typeops import fixup_partial_type, make_simplified_union
from mypy.types import CallableType, Instance, LiteralType, ProperType, Type, get_proper_type
from mypy.types import (
CallableType,
Instance,
LiteralType,
ProperType,
Type,
UnionType,
get_proper_type,
)

ENUM_NAME_ACCESS: Final = {f"{prefix}.name" for prefix in ENUM_BASES} | {
f"{prefix}._name_" for prefix in ENUM_BASES
Expand Down Expand Up @@ -148,65 +156,15 @@ class SomeEnum:
# same value-type, then it doesn't matter which member was passed in.
# The value-type is still known.
if isinstance(ctx.type, Instance):
info = ctx.type.type

# As long as mypy doesn't understand attribute creation in __new__,
# there is no way to predict the value type if the enum class has a
# custom implementation
if _implements_new(info):
return ctx.default_attr_type

stnodes = (info.get(name) for name in info.names)

# Enums _can_ have methods and instance attributes.
# Omit methods and attributes created by assigning to self.*
# for our value inference.
node_types = (
get_proper_type(n.type) if n else None
for n in stnodes
if n is None or not n.implicit
)
proper_types = list(
_infer_value_type_with_auto_fallback(ctx, t)
for t in node_types
if t is None or not isinstance(t, CallableType)
)
underlying_type = _first(proper_types)
if underlying_type is None:
return ctx.default_attr_type

# At first we try to predict future `value` type if all other items
# have the same type. For example, `int`.
# If this is the case, we simply return this type.
# See https://github.com/python/mypy/pull/9443
all_same_value_type = all(
proper_type is not None and proper_type == underlying_type
for proper_type in proper_types
)
if all_same_value_type:
if underlying_type is not None:
return underlying_type

# But, after we started treating all `Enum` values as `Final`,
# we start to infer types in
# `item = 1` as `Literal[1]`, not just `int`.
# So, for example types in this `Enum` will all be different:
#
# class Ordering(IntEnum):
# one = 1
# two = 2
# three = 3
#
# We will infer three `Literal` types here.
# They are not the same, but they are equivalent.
# So, we unify them to make sure `.value` prediction still works.
# Result will be `Literal[1] | Literal[2] | Literal[3]` for this case.
all_equivalent_types = all(
proper_type is not None and is_equivalent(proper_type, underlying_type)
for proper_type in proper_types
)
if all_equivalent_types:
return make_simplified_union(cast(Sequence[Type], proper_types))
return _infer_enum_value_type(ctx.type.type, ctx)
elif isinstance(ctx.type, UnionType):
union_items = []
for item in ctx.type.items:
proper_item = get_proper_type(item)
if not isinstance(proper_item, Instance):
return ctx.default_attr_type
union_items.append(_infer_enum_value_type(proper_item.type, ctx))
return make_simplified_union(union_items)
return ctx.default_attr_type

assert isinstance(ctx.type, Instance)
Expand Down Expand Up @@ -256,3 +214,61 @@ def _extract_underlying_field_name(typ: Type) -> str | None:
# as a string.
assert isinstance(underlying_literal.value, str)
return underlying_literal.value


def _infer_enum_value_type(info: TypeInfo, ctx: mypy.plugin.AttributeContext) -> Type:
# As long as mypy doesn't understand attribute creation in __new__,
# there is no way to predict the value type if the enum class has a
# custom implementation
if _implements_new(info):
return ctx.default_attr_type

stnodes = (info.get(name) for name in info.names)

# Enums _can_ have methods and instance attributes.
# Omit methods and attributes created by assigning to self.*
# for our value inference.
node_types = (
get_proper_type(n.type) if n else None for n in stnodes if n is None or not n.implicit
)
proper_types = list(
_infer_value_type_with_auto_fallback(ctx, t)
for t in node_types
if t is None or not isinstance(t, CallableType)
)
underlying_type = _first(proper_types)
if underlying_type is None:
return ctx.default_attr_type

# At first we try to predict future `value` type if all other items
# have the same type. For example, `int`.
# If this is the case, we simply return this type.
# See https://github.com/python/mypy/pull/9443
all_same_value_type = all(
proper_type is not None and proper_type == underlying_type for proper_type in proper_types
)
if all_same_value_type:
if underlying_type is not None:
return underlying_type

# But, after we started treating all `Enum` values as `Final`,
# we start to infer types in
# `item = 1` as `Literal[1]`, not just `int`.
# So, for example types in this `Enum` will all be different:
#
# class Ordering(IntEnum):
# one = 1
# two = 2
# three = 3
#
# We will infer three `Literal` types here.
# They are not the same, but they are equivalent.
# So, we unify them to make sure `.value` prediction still works.
# Result will be `Literal[1] | Literal[2] | Literal[3]` for this case.
all_equivalent_types = all(
proper_type is not None and is_equivalent(proper_type, underlying_type)
for proper_type in proper_types
)
if all_equivalent_types:
return make_simplified_union(cast(Sequence[Type], proper_types))
return ctx.default_attr_type
53 changes: 53 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,59 @@ def infer_truth(truth: Truth) -> None:
reveal_type(truth.value) # N: Revealed type is "builtins.bool"
[builtins fixtures/bool.pyi]

[case testEnumNameAndValueForUnion]
from typing import Union
from enum import Enum
class E1(Enum):
A = 1
B = 2
class E2(Enum):
A = 3
B = 4
e: Union[E1, E2]
reveal_type(e.value) # N: Revealed type is "Union[Literal[1]?, Literal[2]?, Literal[3]?, Literal[4]?]"
[builtins fixtures/bool.pyi]

[case testEnumNameAndValueForUnionOfHeterogenousEnums]
from typing import Union
from enum import Enum
class E1(Enum):
A = 1
B = 'b'
class E2(Enum):
C = 3
D = 'd'
e: Union[E1, E2]
reveal_type(e.value) # N: Revealed type is "Any"
[builtins fixtures/bool.pyi]

[case testEnumNameAndValueForHeterogenousUnionOfHomogenousEnums]
from typing import Union
from enum import Enum
class E1(Enum):
A = 1
B = 2
class E2(Enum):
C = 'c'
D = 'd'
e: Union[E1, E2]
reveal_type(e.value) # N: Revealed type is "Union[Literal[1]?, Literal[2]?, Literal['c']?, Literal['d']?]"
[builtins fixtures/bool.pyi]

[case testEnumNameAndValueForEnumsWithNonInstanceItems]
from typing import Union
from typing_extensions import Literal
from enum import Enum
class E1(Enum):
A = 1
B = 2
class E2(Enum):
C = 3
D = 4
e: Union[E1, Literal[E2.C]]
reveal_type(e.value) # N: Revealed type is "Any"
[builtins fixtures/bool.pyi]

[case testEnumUnique]
import enum
@enum.unique
Expand Down