Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow None vs TypeVar overlap for overloads #15846

Merged
merged 6 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7216,22 +7216,32 @@ def is_unsafe_overlapping_overload_signatures(
#
# This discrepancy is unfortunately difficult to get rid of, so we repeat the
# checks twice in both directions for now.
#
# Note that we ignore possible overlap between type variables and None. This
# is technically unsafe, but unsafety is tiny and this prevents some common
# use cases like:
# @overload
# def foo(x: None) -> None: ..
# @overload
# def foo(x: T) -> Foo[T]: ...
return is_callable_compatible(
signature,
other,
is_compat=is_overlapping_types_no_promote_no_uninhabited,
is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
is_compat_return=lambda l, r: not is_subtype_no_promote(l, r),
ignore_return=False,
check_args_covariantly=True,
allow_partial_overlap=True,
no_unify_none=True,
) or is_callable_compatible(
other,
signature,
is_compat=is_overlapping_types_no_promote_no_uninhabited,
is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
is_compat_return=lambda l, r: not is_subtype_no_promote(r, l),
ignore_return=False,
check_args_covariantly=False,
allow_partial_overlap=True,
no_unify_none=True,
)


Expand Down Expand Up @@ -7717,12 +7727,18 @@ def is_subtype_no_promote(left: Type, right: Type) -> bool:
return is_subtype(left, right, ignore_promotions=True)


def is_overlapping_types_no_promote_no_uninhabited(left: Type, right: Type) -> bool:
def is_overlapping_types_no_promote_no_uninhabited_no_none(left: Type, right: Type) -> bool:
# For the purpose of unsafe overload checks we consider list[<nothing>] and list[int]
# non-overlapping. This is consistent with how we treat list[int] and list[str] as
# non-overlapping, despite [] belongs to both. Also this will prevent false positives
# for failed type inference during unification.
return is_overlapping_types(left, right, ignore_promotions=True, ignore_uninhabited=True)
return is_overlapping_types(
left,
right,
ignore_promotions=True,
ignore_uninhabited=True,
prohibit_none_typevar_overlap=True,
)


def is_private(node_name: str) -> bool:
Expand Down
86 changes: 69 additions & 17 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2400,6 +2400,11 @@ def check_overload_call(
# typevar. See https://github.com/python/mypy/issues/4063 for related discussion.
erased_targets: list[CallableType] | None = None
unioned_result: tuple[Type, Type] | None = None

# Determine whether we need to encourage union math. This should be generally safe,
# as union math infers better results in the vast majority of cases, but it is very
# computationally intensive.
none_type_var_overlap = self.possible_none_type_var_overlap(arg_types, plausible_targets)
union_interrupted = False # did we try all union combinations?
if any(self.real_union(arg) for arg in arg_types):
try:
Expand All @@ -2412,6 +2417,7 @@ def check_overload_call(
arg_names,
callable_name,
object_type,
none_type_var_overlap,
context,
)
except TooManyUnions:
Expand Down Expand Up @@ -2444,8 +2450,10 @@ def check_overload_call(
# If any of checks succeed, stop early.
if inferred_result is not None and unioned_result is not None:
# Both unioned and direct checks succeeded, choose the more precise type.
if is_subtype(inferred_result[0], unioned_result[0]) and not isinstance(
get_proper_type(inferred_result[0]), AnyType
if (
is_subtype(inferred_result[0], unioned_result[0])
and not isinstance(get_proper_type(inferred_result[0]), AnyType)
and not none_type_var_overlap
):
return inferred_result
return unioned_result
Expand Down Expand Up @@ -2495,7 +2503,8 @@ def check_overload_call(
callable_name=callable_name,
object_type=object_type,
)
if union_interrupted:
# Do not show the extra error if the union math was forced.
if union_interrupted and not none_type_var_overlap:
self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context)
return result

Expand Down Expand Up @@ -2650,6 +2659,44 @@ def overload_erased_call_targets(
matches.append(typ)
return matches

def possible_none_type_var_overlap(
self, arg_types: list[Type], plausible_targets: list[CallableType]
) -> bool:
"""Heuristic to determine whether we need to try forcing union math.

This is needed to avoid greedy type variable match in situations like this:
@overload
def foo(x: None) -> None: ...
@overload
def foo(x: T) -> list[T]: ...

x: int | None
foo(x)
we want this call to infer list[int] | None, not list[int | None].
"""
if not plausible_targets or not arg_types:
return False
has_optional_arg = False
for arg_type in get_proper_types(arg_types):
if not isinstance(arg_type, UnionType):
continue
for item in get_proper_types(arg_type.items):
if isinstance(item, NoneType):
has_optional_arg = True
break
if not has_optional_arg:
return False

min_prefix = min(len(c.arg_types) for c in plausible_targets)
for i in range(min_prefix):
if any(
isinstance(get_proper_type(c.arg_types[i]), NoneType) for c in plausible_targets
) and any(
isinstance(get_proper_type(c.arg_types[i]), TypeVarType) for c in plausible_targets
):
return True
return False

def union_overload_result(
self,
plausible_targets: list[CallableType],
Expand All @@ -2659,6 +2706,7 @@ def union_overload_result(
arg_names: Sequence[str | None] | None,
callable_name: str | None,
object_type: Type | None,
none_type_var_overlap: bool,
context: Context,
level: int = 0,
) -> list[tuple[Type, Type]] | None:
Expand Down Expand Up @@ -2698,20 +2746,23 @@ def union_overload_result(

# Step 3: Try a direct match before splitting to avoid unnecessary union splits
# and save performance.
with self.type_overrides_set(args, arg_types):
direct = self.infer_overload_return_type(
plausible_targets,
args,
arg_types,
arg_kinds,
arg_names,
callable_name,
object_type,
context,
)
if direct is not None and not isinstance(get_proper_type(direct[0]), (UnionType, AnyType)):
# We only return non-unions soon, to avoid greedy match.
return [direct]
if not none_type_var_overlap:
with self.type_overrides_set(args, arg_types):
direct = self.infer_overload_return_type(
plausible_targets,
args,
arg_types,
arg_kinds,
arg_names,
callable_name,
object_type,
context,
)
if direct is not None and not isinstance(
get_proper_type(direct[0]), (UnionType, AnyType)
):
# We only return non-unions soon, to avoid greedy match.
return [direct]

# Step 4: Split the first remaining union type in arguments into items and
# try to match each item individually (recursive).
Expand All @@ -2729,6 +2780,7 @@ def union_overload_result(
arg_names,
callable_name,
object_type,
none_type_var_overlap,
context,
level + 1,
)
Expand Down
15 changes: 13 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,7 @@ def is_callable_compatible(
check_args_covariantly: bool = False,
allow_partial_overlap: bool = False,
strict_concatenate: bool = False,
no_unify_none: bool = False,
) -> bool:
"""Is the left compatible with the right, using the provided compatibility check?

Expand Down Expand Up @@ -1415,7 +1416,9 @@ def g(x: int) -> int: ...
# (below) treats type variables on the two sides as independent.
if left.variables:
# Apply generic type variables away in left via type inference.
unified = unify_generic_callable(left, right, ignore_return=ignore_return)
unified = unify_generic_callable(
left, right, ignore_return=ignore_return, no_unify_none=no_unify_none
)
if unified is None:
return False
left = unified
Expand All @@ -1427,7 +1430,9 @@ def g(x: int) -> int: ...
# So, we repeat the above checks in the opposite direction. This also
# lets us preserve the 'symmetry' property of allow_partial_overlap.
if allow_partial_overlap and right.variables:
unified = unify_generic_callable(right, left, ignore_return=ignore_return)
unified = unify_generic_callable(
right, left, ignore_return=ignore_return, no_unify_none=no_unify_none
)
if unified is not None:
right = unified

Expand Down Expand Up @@ -1687,6 +1692,8 @@ def unify_generic_callable(
target: NormalizedCallableType,
ignore_return: bool,
return_constraint_direction: int | None = None,
*,
no_unify_none: bool = False,
) -> NormalizedCallableType | None:
"""Try to unify a generic callable type with another callable type.

Expand All @@ -1708,6 +1715,10 @@ def unify_generic_callable(
type.ret_type, target.ret_type, return_constraint_direction
)
constraints.extend(c)
if no_unify_none:
constraints = [
c for c in constraints if not isinstance(get_proper_type(c.target), NoneType)
]
inferred_vars, _ = mypy.solve.solve_constraints(type.variables, constraints)
if None in inferred_vars:
return None
Expand Down
39 changes: 33 additions & 6 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -2185,36 +2185,63 @@ def bar2(*x: int) -> int: ...
[builtins fixtures/tuple.pyi]

[case testOverloadDetectsPossibleMatchesWithGenerics]
from typing import overload, TypeVar, Generic
# flags: --strict-optional
from typing import overload, TypeVar, Generic, Optional, List

T = TypeVar('T')
# The examples below are unsafe, but it is a quite common pattern
# so we ignore the possibility of type variables taking value `None`
# for the purpose of overload overlap checks.

@overload
def foo(x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def foo(x: None, y: None) -> str: ...
@overload
def foo(x: T, y: T) -> int: ...
def foo(x): ...
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test also calling this to make sure we infer the return type correctly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh well, this discovered an obvious flaw in my solution, for this

@overload
def foo(x: None) -> None: ...
@overload
def foo(x: T) -> list[T]: ...

x: int | None
foo(x)

we will continue to infer list[int | None], not list[int] | None that people will likely expect. I will try to check if there is a simple way to force union-math for this case. If there are none, this PR may become much more complex.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I think I have come up with some heuristics to force union math for None vs TypeVar overlap situations. It is not very precise, but should be generally safe, since IIUC union math infer better types anyway, we usually try to avoid it simply because it is computationally intensive.


oi: Optional[int]
reveal_type(foo(None, None)) # N: Revealed type is "builtins.str"
reveal_type(foo(None, 42)) # N: Revealed type is "builtins.int"
reveal_type(foo(42, 42)) # N: Revealed type is "builtins.int"
reveal_type(foo(oi, None)) # N: Revealed type is "Union[builtins.int, builtins.str]"
reveal_type(foo(oi, 42)) # N: Revealed type is "builtins.int"
reveal_type(foo(oi, oi)) # N: Revealed type is "Union[builtins.int, builtins.str]"

@overload
def foo_list(x: None) -> None: ...
@overload
def foo_list(x: T) -> List[T]: ...
def foo_list(x): ...

reveal_type(foo_list(oi)) # N: Revealed type is "Union[builtins.list[builtins.int], None]"

# What if 'T' is 'object'?
@overload
def bar(x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def bar(x: None, y: int) -> str: ...
@overload
def bar(x: T, y: T) -> int: ...
def bar(x, y): ...

class Wrapper(Generic[T]):
@overload
def foo(self, x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def foo(self, x: None, y: None) -> str: ...
@overload
def foo(self, x: T, y: None) -> int: ...
def foo(self, x): ...

@overload
def bar(self, x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def bar(self, x: None, y: int) -> str: ...
@overload
def bar(self, x: T, y: T) -> int: ...
def bar(self, x, y): ...

@overload
def baz(x: str, y: str) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
@overload
def baz(x: T, y: T) -> int: ...
def baz(x): ...
[builtins fixtures/tuple.pyi]

[case testOverloadFlagsPossibleMatches]
from wrapper import *
[file wrapper.pyi]
Expand Down Expand Up @@ -3996,7 +4023,7 @@ T = TypeVar('T')

class FakeAttribute(Generic[T]):
@overload
def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ...
@overload
def dummy(self, instance: T, owner: Type[T]) -> int: ...
def dummy(self, instance: Optional[T], owner: Type[T]) -> Union['FakeAttribute[T]', int]: ...
Expand Down