Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not ignore generic type args when checking multiple inheritance compatibility #18270

81 changes: 63 additions & 18 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2739,26 +2739,57 @@ def check_protocol_variance(self, defn: ClassDef) -> None:
if expected != tvar.variance:
self.msg.bad_proto_variance(tvar.variance, tvar.name, expected, defn)

def get_parameterized_base_classes(self, typ: TypeInfo) -> list[Instance]:
"""Build an MRO-like structure with generic type args substituted.

Excludes the class itself.

When several bases have a common ancestor, includes an :class:`Instance`
for each param.
"""
bases = []
for parent in typ.mro[1:]:
if parent.is_generic():
for base in typ.bases:
if parent in base.type.mro:
bases.append(map_instance_to_supertype(base, parent))
else:
bases.append(Instance(parent, []))
return bases

def check_multiple_inheritance(self, typ: TypeInfo) -> None:
"""Check for multiple inheritance related errors."""
if len(typ.bases) <= 1:
# No multiple inheritance.
return

# Verify that inherited attributes are compatible.
mro = typ.mro[1:]
for i, base in enumerate(mro):
typed_mro = self.get_parameterized_base_classes(typ)
# If the first MRO entry is compatible with everything following, we don't need
# (and shouldn't) compare further pairs
# (see testMultipleInheritanceExplcitDiamondResolution)
seen_names = set()
for i, base in enumerate(typed_mro):
# Attributes defined in both the type and base are skipped.
# Normal checks for attribute compatibility should catch any problems elsewhere.
non_overridden_attrs = base.names.keys() - typ.names.keys()
# Sort for consistent messages order.
non_overridden_attrs = sorted(typed_mro[i].type.names - typ.names.keys())
for name in non_overridden_attrs:
if is_private(name):
continue
for base2 in mro[i + 1 :]:
if name in seen_names:
continue
for base2 in typed_mro[i + 1 :]:
# We only need to check compatibility of attributes from classes not
# in a subclass relationship. For subclasses, normal (single inheritance)
# checks suffice (these are implemented elsewhere).
if name in base2.names and base2 not in base.mro:
if name in base2.type.names and not is_subtype(
base, base2, ignore_promotions=True
):
# If base1 already inherits from base2 with correct type args,
# we have reported errors if any. Avoid reporting them again.
self.check_compatibility(name, base, base2, typ)
seen_names.add(name)

def determine_type_of_member(self, sym: SymbolTableNode) -> Type | None:
if sym.type is not None:
Expand All @@ -2783,8 +2814,23 @@ def determine_type_of_member(self, sym: SymbolTableNode) -> Type | None:
# TODO: handle more node kinds here.
return None

def attribute_type_from_base(
self, name: str, base: Instance
) -> tuple[ProperType | None, SymbolTableNode]:
"""For a NameExpr that is part of a class, walk all base classes and try
to find the first class that defines a Type for the same name."""
base_var = base.type[name]
base_type = self.determine_type_of_member(base_var)
if base_type is None:
return None, base_var

if not has_no_typevars(base_type):
base_type = expand_type_by_instance(base_type, base)

return get_proper_type(base_type), base_var

def check_compatibility(
self, name: str, base1: TypeInfo, base2: TypeInfo, ctx: TypeInfo
self, name: str, base1: Instance, base2: Instance, ctx: TypeInfo
) -> None:
"""Check if attribute name in base1 is compatible with base2 in multiple inheritance.

Expand All @@ -2809,10 +2855,9 @@ class C(B, A[int]): ... # this is unsafe because...
if name in ("__init__", "__new__", "__init_subclass__"):
# __init__ and friends can be incompatible -- it's a special case.
return
first = base1.names[name]
second = base2.names[name]
first_type = get_proper_type(self.determine_type_of_member(first))
second_type = get_proper_type(self.determine_type_of_member(second))

first_type, first = self.attribute_type_from_base(name, base1)
second_type, second = self.attribute_type_from_base(name, base2)

# TODO: use more principled logic to decide is_subtype() vs is_equivalent().
# We should rely on mutability of superclass node, not on types being Callable.
Expand All @@ -2822,7 +2867,7 @@ class C(B, A[int]): ... # this is unsafe because...
if isinstance(first_type, Instance):
call = find_member("__call__", first_type, first_type, is_operator=True)
if call and isinstance(second_type, FunctionLike):
second_sig = self.bind_and_map_method(second, second_type, ctx, base2)
second_sig = self.bind_and_map_method(second, second_type, ctx, base2.type)
ok = is_subtype(call, second_sig, ignore_pos_arg_names=True)
elif isinstance(first_type, FunctionLike) and isinstance(second_type, FunctionLike):
if first_type.is_type_obj() and second_type.is_type_obj():
Expand All @@ -2834,8 +2879,8 @@ class C(B, A[int]): ... # this is unsafe because...
)
else:
# First bind/map method types when necessary.
first_sig = self.bind_and_map_method(first, first_type, ctx, base1)
second_sig = self.bind_and_map_method(second, second_type, ctx, base2)
first_sig = self.bind_and_map_method(first, first_type, ctx, base1.type)
second_sig = self.bind_and_map_method(second, second_type, ctx, base2.type)
ok = is_subtype(first_sig, second_sig, ignore_pos_arg_names=True)
elif first_type and second_type:
if isinstance(first.node, Var):
Expand All @@ -2844,7 +2889,7 @@ class C(B, A[int]): ... # this is unsafe because...
second_type = expand_self_type(second.node, second_type, fill_typevars(ctx))
ok = is_equivalent(first_type, second_type)
if not ok:
second_node = base2[name].node
second_node = base2.type[name].node
if (
isinstance(second_type, FunctionLike)
and second_node is not None
Expand All @@ -2854,22 +2899,22 @@ class C(B, A[int]): ... # this is unsafe because...
ok = is_subtype(first_type, second_type)
else:
if first_type is None:
self.msg.cannot_determine_type_in_base(name, base1.name, ctx)
self.msg.cannot_determine_type_in_base(name, base1.type.name, ctx)
if second_type is None:
self.msg.cannot_determine_type_in_base(name, base2.name, ctx)
self.msg.cannot_determine_type_in_base(name, base2.type.name, ctx)
ok = True
# Final attributes can never be overridden, but can override
# non-final read-only attributes.
if is_final_node(second.node) and not is_private(name):
self.msg.cant_override_final(name, base2.name, ctx)
self.msg.cant_override_final(name, base2.type.name, ctx)
if is_final_node(first.node):
self.check_if_final_var_override_writable(name, second.node, ctx)
# Some attributes like __slots__ and __deletable__ are special, and the type can
# vary across class hierarchy.
if isinstance(second.node, Var) and second.node.allow_incompatible_override:
ok = True
if not ok:
self.msg.base_class_definitions_incompatible(name, base1, base2, ctx)
self.msg.base_class_definitions_incompatible(name, base1.type, base2.type, ctx)

def check_metaclass_compatibility(self, typ: TypeInfo) -> None:
"""Ensures that metaclasses of all parent types are compatible."""
Expand Down
50 changes: 45 additions & 5 deletions test-data/unit/check-generic-subtyping.test
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ x1: X1[str, int]
reveal_type(list(x1)) # N: Revealed type is "builtins.list[builtins.int]"
reveal_type([*x1]) # N: Revealed type is "builtins.list[builtins.int]"

class X2(Generic[T, U], Iterator[U], Mapping[T, U]):
class X2(Generic[T, U], Iterator[U], Mapping[T, U]): # E: Definition of "__iter__" in base class "Iterable" is incompatible with definition in base class "Iterable"
pass

x2: X2[str, int]
Expand Down Expand Up @@ -1017,10 +1017,7 @@ x1: X1[str, int]
reveal_type(iter(x1)) # N: Revealed type is "typing.Iterator[builtins.int]"
reveal_type({**x1}) # N: Revealed type is "builtins.dict[builtins.int, builtins.str]"

# Some people would expect this to raise an error, but this currently does not:
# `Mapping` has `Iterable[U]` base class, `X2` has direct `Iterable[T]` base class.
# It would be impossible to define correct `__iter__` method for incompatible `T` and `U`.
class X2(Generic[T, U], Mapping[U, T], Iterable[T]):
class X2(Generic[T, U], Mapping[U, T], Iterable[T]): # E: Definition of "__iter__" in base class "Iterable" is incompatible with definition in base class "Iterable"
pass

x2: X2[str, int]
Expand Down Expand Up @@ -1065,3 +1062,46 @@ class F(E[T_co], Generic[T_co]): ... # E: Variance of TypeVar "T_co" incompatib

class G(Generic[T]): ...
class H(G[T_contra], Generic[T_contra]): ... # E: Variance of TypeVar "T_contra" incompatible with variance in parent type

[case testMultipleInheritanceCompatibleTypeVar]
from typing import Generic, TypeVar

T = TypeVar("T")
U = TypeVar("U")

class A(Generic[T]):
x: T
def fn(self, t: T) -> None: ...

class A2(A[T]):
y: str
z: str

class B(Generic[T]):
x: T
def fn(self, t: T) -> None: ...

class C1(A2[str], B[str]): pass
class C2(A2[str], B[int]): pass # E: Definition of "fn" in base class "A" is incompatible with definition in base class "B" \
# E: Definition of "x" in base class "A" is incompatible with definition in base class "B"
class C3(A2[T], B[T]): pass
class C4(A2[U], B[U]): pass
class C5(A2[U], B[T]): pass # E: Definition of "fn" in base class "A" is incompatible with definition in base class "B" \
# E: Definition of "x" in base class "A" is incompatible with definition in base class "B"

[builtins fixtures/tuple.pyi]

[case testMultipleInheritanceNestedTypeVarPropagation]
from typing import Generic, TypeVar

T = TypeVar("T")

class A(Generic[T]):
foo: T
class B(A[str]): ...
class C(B): ...
class D(C): ...

class Bad(D, A[T]): ... # E: Definition of "foo" in base class "A" is incompatible with definition in base class "A"
class Good(D, A[str]): ... # OK
[builtins fixtures/tuple.pyi]
45 changes: 44 additions & 1 deletion test-data/unit/check-multiple-inheritance.test
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,6 @@ class D2(B[Union[int, str]], C2): ...
class D3(C2, B[str]): ...
class D4(B[str], C2): ... # E: Definition of "foo" in base class "A" is incompatible with definition in base class "C2"


[case testMultipleInheritanceOverridingOfFunctionsWithCallableInstances]
from typing import Any, Callable

Expand Down Expand Up @@ -706,3 +705,47 @@ class C34(B3, B4): ...
class C41(B4, B1): ...
class C42(B4, B2): ...
class C43(B4, B3): ...

[case testMultipleInheritanceTransitive]
class A:
def fn(self, x: int) -> None: ...
class B(A): ...
class C(A):
def fn(self, x: "int | str") -> None: ...
class D(B, C): ...

[case testMultipleInheritanceCompatErrorPropagation]
class A:
foo: bytes
class B(A):
foo: str # type: ignore[assignment]

class Ok(B, A): pass

class C(A): pass
class Ok2(B, C): pass

[case testMultipleInheritanceExplcitDiamondResolution]
class A:
class M:
pass

class B0(A):
class M(A.M):
pass

class B1(A):
class M(A.M):
pass

class C(B0,B1):
class M(B0.M, B1.M):
pass

class D0(B0):
pass
class D1(B1):
pass

class D(D0,D1,C):
pass
Loading