Skip to content

Commit

Permalink
Add heuristic to enforce union math for None vs TypeVar overlap
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivan Levkivskyi committed Aug 11, 2023
1 parent 494016a commit 05eb6ff
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 17 deletions.
81 changes: 65 additions & 16 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2400,6 +2400,11 @@ def check_overload_call(
# typevar. See https://github.com/python/mypy/issues/4063 for related discussion.
erased_targets: list[CallableType] | None = None
unioned_result: tuple[Type, Type] | None = None

# Determine whether we need to encourage union math. This should be generally safe,
# as union math infers better results in the vast majority of cases, but it is very
# computationally intensive.
none_type_var_overlap = self.possible_none_type_var_overlap(arg_types, plausible_targets)
union_interrupted = False # did we try all union combinations?
if any(self.real_union(arg) for arg in arg_types):
try:
Expand All @@ -2412,6 +2417,7 @@ def check_overload_call(
arg_names,
callable_name,
object_type,
none_type_var_overlap,
context,
)
except TooManyUnions:
Expand Down Expand Up @@ -2444,8 +2450,10 @@ def check_overload_call(
# If any of checks succeed, stop early.
if inferred_result is not None and unioned_result is not None:
# Both unioned and direct checks succeeded, choose the more precise type.
if is_subtype(inferred_result[0], unioned_result[0]) and not isinstance(
get_proper_type(inferred_result[0]), AnyType
if (
is_subtype(inferred_result[0], unioned_result[0])
and not isinstance(get_proper_type(inferred_result[0]), AnyType)
and not none_type_var_overlap
):
return inferred_result
return unioned_result
Expand Down Expand Up @@ -2650,6 +2658,42 @@ def overload_erased_call_targets(
matches.append(typ)
return matches

def possible_none_type_var_overlap(
    self, arg_types: list[Type], plausible_targets: list[CallableType]
) -> bool:
    """Heuristic to determine whether we need to try forcing union math.

    This is needed to avoid greedy type variable match in situations like this:
        @overload
        def foo(x: None) -> None: ...
        @overload
        def foo(x: T) -> list[T]: ...

        x: int | None
        foo(x)
    we want this call to infer list[int] | None, not list[int | None].

    Args:
        arg_types: Types of the actual arguments at the call site.
        plausible_targets: Overload items being considered for the call.

    Returns:
        True if some actual argument is a union containing None, and there is
        a formal argument position (within the shortest target's argument
        list) that is None in one target and a type variable in another.
        False otherwise, including when either input list is empty.
    """
    # Nothing to compare against. This guard also keeps min() below from
    # raising ValueError on an empty sequence of targets.
    if not plausible_targets or not arg_types:
        return False

    # The heuristic only matters if at least one actual argument is a union
    # with a None item (i.e. an "optional" argument).
    has_optional_arg = any(
        isinstance(item, NoneType)
        for arg_type in get_proper_types(arg_types)
        if isinstance(arg_type, UnionType)
        for item in get_proper_types(arg_type.items)
    )
    if not has_optional_arg:
        return False

    # Only positions present in every target can be compared, so limit the
    # scan to the shortest formal argument list.
    min_prefix = min(len(c.arg_types) for c in plausible_targets)
    for i in range(min_prefix):
        formals = [get_proper_type(c.arg_types[i]) for c in plausible_targets]
        if any(isinstance(formal, NoneType) for formal in formals) and any(
            isinstance(formal, TypeVarType) for formal in formals
        ):
            return True
    return False

def union_overload_result(
self,
plausible_targets: list[CallableType],
Expand All @@ -2659,6 +2703,7 @@ def union_overload_result(
arg_names: Sequence[str | None] | None,
callable_name: str | None,
object_type: Type | None,
none_type_var_overlap: bool,
context: Context,
level: int = 0,
) -> list[tuple[Type, Type]] | None:
Expand Down Expand Up @@ -2698,20 +2743,23 @@ def union_overload_result(

# Step 3: Try a direct match before splitting to avoid unnecessary union splits
# and save performance.
with self.type_overrides_set(args, arg_types):
direct = self.infer_overload_return_type(
plausible_targets,
args,
arg_types,
arg_kinds,
arg_names,
callable_name,
object_type,
context,
)
if direct is not None and not isinstance(get_proper_type(direct[0]), (UnionType, AnyType)):
# We only return non-unions soon, to avoid greedy match.
return [direct]
if not none_type_var_overlap:
with self.type_overrides_set(args, arg_types):
direct = self.infer_overload_return_type(
plausible_targets,
args,
arg_types,
arg_kinds,
arg_names,
callable_name,
object_type,
context,
)
if direct is not None and not isinstance(
get_proper_type(direct[0]), (UnionType, AnyType)
):
# We only return non-unions soon, to avoid greedy match.
return [direct]

# Step 4: Split the first remaining union type in arguments into items and
# try to match each item individually (recursive).
Expand All @@ -2729,6 +2777,7 @@ def union_overload_result(
arg_names,
callable_name,
object_type,
none_type_var_overlap,
context,
level + 1,
)
Expand Down
20 changes: 19 additions & 1 deletion test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -2185,7 +2185,8 @@ def bar2(*x: int) -> int: ...
[builtins fixtures/tuple.pyi]

[case testOverloadDetectsPossibleMatchesWithGenerics]
from typing import overload, TypeVar, Generic
# flags: --strict-optional
from typing import overload, TypeVar, Generic, Optional, List

T = TypeVar('T')
# The examples below are unsafe, but it is a quite common pattern
Expand All @@ -2198,6 +2199,22 @@ def foo(x: None, y: None) -> str: ...
def foo(x: T, y: T) -> int: ...
def foo(x): ...

oi: Optional[int]
reveal_type(foo(None, None)) # N: Revealed type is "builtins.str"
reveal_type(foo(None, 42)) # N: Revealed type is "builtins.int"
reveal_type(foo(42, 42)) # N: Revealed type is "builtins.int"
reveal_type(foo(oi, None)) # N: Revealed type is "Union[builtins.int, builtins.str]"
reveal_type(foo(oi, 42)) # N: Revealed type is "builtins.int"
reveal_type(foo(oi, oi)) # N: Revealed type is "Union[builtins.int, builtins.str]"

@overload
def foo_list(x: None) -> None: ...
@overload
def foo_list(x: T) -> List[T]: ...
def foo_list(x): ...

reveal_type(foo_list(oi)) # N: Revealed type is "Union[builtins.list[builtins.int], None]"

# What if 'T' is 'object'?
@overload
def bar(x: None, y: int) -> str: ...
Expand All @@ -2223,6 +2240,7 @@ def baz(x: str, y: str) -> str: ... # E: Overloaded function signatures 1 and 2
@overload
def baz(x: T, y: T) -> int: ...
def baz(x): ...
[builtins fixtures/tuple.pyi]

[case testOverloadFlagsPossibleMatches]
from wrapper import *
Expand Down

0 comments on commit 05eb6ff

Please sign in to comment.