diff --git a/mypy/checker.py b/mypy/checker.py index cdf2ab648545..171c45822791 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1940,6 +1940,7 @@ def bind_and_map_method( sub_info: class where the method is used super_info: class where the method was defined """ + mapped_typ = cast(FunctionLike, map_type_from_supertype(typ, sub_info, super_info)) if isinstance(sym.node, (FuncDef, OverloadedFuncDef, Decorator)) and not is_static( sym.node ): @@ -1947,28 +1948,9 @@ def bind_and_map_method( is_class_method = sym.node.func.is_class else: is_class_method = sym.node.is_class - - mapped_typ = cast(FunctionLike, map_type_from_supertype(typ, sub_info, super_info)) active_self_type = self.scope.active_self_type() - if isinstance(mapped_typ, Overloaded) and active_self_type: - # If we have an overload, filter to overloads that match the self type. - # This avoids false positives for concrete subclasses of generic classes, - # see testSelfTypeOverrideCompatibility for an example. - filtered_items = [ - item - for item in mapped_typ.items - if not item.arg_types or is_subtype(active_self_type, item.arg_types[0]) - ] - # If we don't have any filtered_items, maybe it's always a valid override - # of the superclass? However if you get to that point you're in murky type - # territory anyway, so we just preserve the type and have the behaviour match - # that of older versions of mypy. - if filtered_items: - mapped_typ = Overloaded(filtered_items) - return bind_self(mapped_typ, active_self_type, is_class_method) - else: - return cast(FunctionLike, map_type_from_supertype(typ, sub_info, super_info)) + return mapped_typ def get_op_other_domain(self, tp: FunctionLike) -> Type | None: if isinstance(tp, CallableType): diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 9c6518b9e487..cbd5e731e3e7 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -86,6 +86,7 @@ def __init__( # Supported for both proper and non-proper ignore_promotions: bool = False, ignore_uninhabited: bool = False, + ignore_type_vars: bool = False, # Proper subtype flags erase_instances: bool = False, keep_erased_types: bool = False, @@ -96,6 +97,7 @@ def __init__( self.ignore_declared_variance = ignore_declared_variance self.ignore_promotions = ignore_promotions self.ignore_uninhabited = ignore_uninhabited + self.ignore_type_vars = ignore_type_vars self.erase_instances = erase_instances self.keep_erased_types = keep_erased_types self.options = options @@ -119,6 +121,7 @@ def is_subtype( ignore_declared_variance: bool = False, ignore_promotions: bool = False, ignore_uninhabited: bool = False, + ignore_type_vars: bool = False, options: Options | None = None, ) -> bool: """Is 'left' subtype of 'right'? @@ -139,6 +142,7 @@ def is_subtype( ignore_declared_variance=ignore_declared_variance, ignore_promotions=ignore_promotions, ignore_uninhabited=ignore_uninhabited, + ignore_type_vars=ignore_type_vars, options=options, ) else: @@ -287,6 +291,11 @@ def _is_subtype( # ErasedType as we do for non-proper subtyping. return True + if subtype_context.ignore_type_vars and ( + isinstance(left, TypeVarType) or isinstance(right, TypeVarType) + ): + return True + if isinstance(right, UnionType) and not isinstance(left, UnionType): # Normally, when 'left' is not itself a union, the only way # 'left' can be a subtype of the union 'right' is if it is a diff --git a/mypy/typeops.py b/mypy/typeops.py index 8c01fb118076..f0d40a61b29a 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -143,7 +143,9 @@ def type_object_type_from_function( # ... # # We need to map B's __init__ to the type (List[T]) -> None. - signature = bind_self(signature, original_type=default_self, is_classmethod=is_new) + signature = bind_self( + signature, original_type=default_self, is_classmethod=is_new, selftypes=orig_self_types + ) signature = cast(FunctionLike, map_type_from_supertype(signature, info, def_info)) special_sig: str | None = None @@ -251,7 +253,12 @@ def supported_self_type(typ: ProperType) -> bool: F = TypeVar("F", bound=FunctionLike) -def bind_self(method: F, original_type: Type | None = None, is_classmethod: bool = False) -> F: +def bind_self( + method: F, + original_type: Type | None = None, + is_classmethod: bool = False, + selftypes: list[Type | None] | None = None, +) -> F: """Return a copy of `method`, with the type of its first parameter (usually self or cls) bound to original_type. @@ -274,10 +281,36 @@ class B(A): pass b = B().copy() # type: B """ + + from mypy.subtypes import is_subtype + if isinstance(method, Overloaded): + # Try to remove overload items with non-matching self types first (fixes #14943) + origtype = get_proper_type(original_type) + if isinstance(origtype, Instance): + methoditems = [] + if selftypes is not None: + selftypes_copy = selftypes.copy() + selftypes.clear() + for idx, methoditem in enumerate(method.items): + selftype = get_self_type(methoditem, origtype) + selftype_proper = get_proper_type(selftype) + if not isinstance(selftype_proper, Instance) or is_subtype( + origtype, selftype_proper, ignore_type_vars=True + ): + methoditems.append(methoditem) + if selftypes is not None: + selftypes.append(selftypes_copy[idx]) + if len(methoditems) == 0: + methoditems = method.items + if selftypes is not None: + selftypes.extend(selftypes_copy) + else: + methoditems = method.items return cast( - F, Overloaded([bind_self(c, original_type, is_classmethod) for c in method.items]) + F, Overloaded([bind_self(mi, original_type, is_classmethod) for mi in methoditems]) ) + assert isinstance(method, CallableType) func = method if not func.arg_types: diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index 182745b99e40..7c7d352d9b22 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -4020,3 +4020,42 @@ class P(Protocol): [file lib.py] class C: ... + +[case TestOverloadedMethodWithExplictSelfTypes] +from typing import Generic, overload, Protocol, TypeVar, Union + +AnyStr = TypeVar("AnyStr", str, bytes) +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + +class SupportsRead(Protocol[T_co]): + def read(self) -> T_co: ... + +class SupportsWrite(Protocol[T_contra]): + def write(self, s: T_contra) -> int: ... + +class Input(Generic[AnyStr]): + def read(self) -> AnyStr: ... + +class Output(Generic[AnyStr]): + @overload + def write(self: Output[str], s: str) -> int: ... + @overload + def write(self: Output[bytes], s: bytes) -> int: ... + def write(self, s: Union[str, bytes]) -> int: ... + +def f(src: SupportsRead[AnyStr], dst: SupportsWrite[AnyStr]) -> None: ... + +def g1(a: Input[bytes], b: Output[bytes]) -> None: + f(a, b) + +def g2(a: Input[bytes], b: Output[bytes]) -> None: + f(a, b) + +def g3(a: Input[str], b: Output[bytes]) -> None: + f(a, b) # E: Cannot infer type argument 1 of "f" + +def g4(a: Input[bytes], b: Output[str]) -> None: + f(a, b) # E: Cannot infer type argument 1 of "f" + +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-selftype.test b/test-data/unit/check-selftype.test index 8083aaf7cf38..53c24584cb73 100644 --- a/test-data/unit/check-selftype.test +++ b/test-data/unit/check-selftype.test @@ -232,7 +232,7 @@ class C(A[None]): # N: def f(self, s: int) -> int [builtins fixtures/tuple.pyi] -[case testSelfTypeOverrideCompatibilityTypeVar-xfail] +[case testSelfTypeOverrideCompatibilityTypeVar] from typing import overload, TypeVar, Union AT = TypeVar("AT", bound="A") @@ -266,6 +266,26 @@ class B(A): def f(*a, **kw): ... [builtins fixtures/dict.pyi] +[case testSelfTypeOverrideCompatibilitySelfTypeVar] +from typing import Any, Generic, Self, TypeVar, overload + +T_co = TypeVar('T_co', covariant=True) + +class Config(Generic[T_co]): + @overload + def get(self, instance: None) -> Self: ... + @overload + def get(self, instance: Any) -> T_co: ... + def get(self, *a, **kw): ... + +class MultiConfig(Config[T_co]): + @overload + def get(self, instance: None) -> Self: ... + @overload + def get(self, instance: Any) -> T_co: ... + def get(self, *a, **kw): ... +[builtins fixtures/dict.pyi] + [case testSelfTypeSuper] from typing import TypeVar, cast