diff --git a/.github/workflows/mypy_primer.yml b/.github/workflows/mypy_primer.yml index ee868484751e..532e77a0cacb 100644 --- a/.github/workflows/mypy_primer.yml +++ b/.github/workflows/mypy_primer.yml @@ -67,6 +67,7 @@ jobs: --debug \ --additional-flags="--debug-serialize" \ --output concise \ + --show-speed-regression \ | tee diff_${{ matrix.shard-index }}.txt ) || [ $? -eq 1 ] - if: ${{ matrix.shard-index == 0 }} diff --git a/mypy/checker.py b/mypy/checker.py index 9c389cccd95f..2612bcc1defb 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -13,6 +13,7 @@ from mypy import errorcodes as codes, join, message_registry, nodes, operators from mypy.binder import ConditionalTypeBinder, Frame, get_declaration from mypy.checker_shared import CheckerScope, TypeCheckerSharedApi, TypeRange +from mypy.checker_state import checker_state from mypy.checkmember import ( MemberContext, analyze_class_attribute_access, @@ -453,7 +454,7 @@ def check_first_pass(self) -> None: Deferred functions will be processed by check_second_pass(). """ self.recurse_into_functions = True - with state.strict_optional_set(self.options.strict_optional): + with state.strict_optional_set(self.options.strict_optional), checker_state.set(self): self.errors.set_file( self.path, self.tree.fullname, scope=self.tscope, options=self.options ) @@ -494,7 +495,7 @@ def check_second_pass( This goes through deferred nodes, returning True if there were any. """ self.recurse_into_functions = True - with state.strict_optional_set(self.options.strict_optional): + with state.strict_optional_set(self.options.strict_optional), checker_state.set(self): if not todo and not self.deferred_nodes: return False self.errors.set_file( diff --git a/mypy/checker_state.py b/mypy/checker_state.py new file mode 100644 index 000000000000..9b988ad18ba4 --- /dev/null +++ b/mypy/checker_state.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Final + +from mypy.checker_shared import TypeCheckerSharedApi + +# This is global mutable state. Don't add anything here unless there's a very +# good reason. + + +class TypeCheckerState: + # Wrap this in a class since it's faster that using a module-level attribute. + + def __init__(self, type_checker: TypeCheckerSharedApi | None) -> None: + # Value varies by file being processed + self.type_checker = type_checker + + @contextmanager + def set(self, value: TypeCheckerSharedApi) -> Iterator[None]: + saved = self.type_checker + self.type_checker = value + try: + yield + finally: + self.type_checker = saved + + +checker_state: Final = TypeCheckerState(type_checker=None) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index cc104fed0752..b89452d90392 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -97,6 +97,7 @@ def __init__( is_self: bool = False, rvalue: Expression | None = None, suppress_errors: bool = False, + preserve_type_var_ids: bool = False, ) -> None: self.is_lvalue = is_lvalue self.is_super = is_super @@ -113,6 +114,10 @@ def __init__( assert is_lvalue self.rvalue = rvalue self.suppress_errors = suppress_errors + # This attribute is only used to preserve old protocol member access logic. + # It is needed to avoid infinite recursion in cases involving self-referential + # generic methods, see find_member() for details. Do not use for other purposes! + self.preserve_type_var_ids = preserve_type_var_ids def named_type(self, name: str) -> Instance: return self.chk.named_type(name) @@ -143,6 +148,7 @@ def copy_modified( no_deferral=self.no_deferral, rvalue=self.rvalue, suppress_errors=self.suppress_errors, + preserve_type_var_ids=self.preserve_type_var_ids, ) if self_type is not None: mx.self_type = self_type @@ -232,8 +238,6 @@ def analyze_member_access( def _analyze_member_access( name: str, typ: Type, mx: MemberContext, override_info: TypeInfo | None = None ) -> Type: - # TODO: This and following functions share some logic with subtypes.find_member; - # consider refactoring. typ = get_proper_type(typ) if isinstance(typ, Instance): return analyze_instance_member_access(name, typ, mx, override_info) @@ -358,7 +362,8 @@ def analyze_instance_member_access( return AnyType(TypeOfAny.special_form) assert isinstance(method.type, Overloaded) signature = method.type - signature = freshen_all_functions_type_vars(signature) + if not mx.preserve_type_var_ids: + signature = freshen_all_functions_type_vars(signature) if not method.is_static: if isinstance(method, (FuncDef, OverloadedFuncDef)) and method.is_trivial_self: signature = bind_self_fast(signature, mx.self_type) @@ -943,7 +948,8 @@ def analyze_var( def expand_without_binding( typ: Type, var: Var, itype: Instance, original_itype: Instance, mx: MemberContext ) -> Type: - typ = freshen_all_functions_type_vars(typ) + if not mx.preserve_type_var_ids: + typ = freshen_all_functions_type_vars(typ) typ = expand_self_type_if_needed(typ, mx, var, original_itype) expanded = expand_type_by_instance(typ, itype) freeze_all_type_vars(expanded) @@ -958,7 +964,8 @@ def expand_and_bind_callable( mx: MemberContext, is_trivial_self: bool, ) -> Type: - functype = freshen_all_functions_type_vars(functype) + if not mx.preserve_type_var_ids: + functype = freshen_all_functions_type_vars(functype) typ = get_proper_type(expand_self_type(var, functype, mx.original_type)) assert isinstance(typ, FunctionLike) if is_trivial_self: @@ -1056,10 +1063,12 @@ def f(self: S) -> T: ... return functype else: selfarg = get_proper_type(item.arg_types[0]) - # This level of erasure matches the one in checker.check_func_def(), - # better keep these two checks consistent. - if subtypes.is_subtype( + # This matches similar special-casing in bind_self(), see more details there. + self_callable = name == "__call__" and isinstance(selfarg, CallableType) + if self_callable or subtypes.is_subtype( dispatched_arg_type, + # This level of erasure matches the one in checker.check_func_def(), + # better keep these two checks consistent. erase_typevars(erase_to_bound(selfarg)), # This is to work around the fact that erased ParamSpec and TypeVarTuple # callables are not always compatible with non-erased ones both ways. @@ -1220,9 +1229,6 @@ def analyze_class_attribute_access( is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or ( isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_class ) - is_staticmethod = (is_decorated and cast(Decorator, node.node).func.is_static) or ( - isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_static - ) t = get_proper_type(t) is_trivial_self = False if isinstance(node.node, Decorator): @@ -1236,8 +1242,7 @@ def analyze_class_attribute_access( t, isuper, is_classmethod, - is_staticmethod, - mx.self_type, + mx, original_vars=original_vars, is_trivial_self=is_trivial_self, ) @@ -1372,8 +1377,7 @@ def add_class_tvars( t: ProperType, isuper: Instance | None, is_classmethod: bool, - is_staticmethod: bool, - original_type: Type, + mx: MemberContext, original_vars: Sequence[TypeVarLikeType] | None = None, is_trivial_self: bool = False, ) -> Type: @@ -1392,9 +1396,6 @@ class B(A[str]): pass isuper: Current instance mapped to the superclass where method was defined, this is usually done by map_instance_to_supertype() is_classmethod: True if this method is decorated with @classmethod - is_staticmethod: True if this method is decorated with @staticmethod - original_type: The value of the type B in the expression B.foo() or the corresponding - component in case of a union (this is used to bind the self-types) original_vars: Type variables of the class callable on which the method was accessed is_trivial_self: if True, we can use fast path for bind_self(). Returns: @@ -1416,14 +1417,14 @@ class B(A[str]): pass # (i.e. appear in the return type of the class object on which the method was accessed). if isinstance(t, CallableType): tvars = original_vars if original_vars is not None else [] - t = freshen_all_functions_type_vars(t) + if not mx.preserve_type_var_ids: + t = freshen_all_functions_type_vars(t) if is_classmethod: if is_trivial_self: - t = bind_self_fast(t, original_type) + t = bind_self_fast(t, mx.self_type) else: - t = bind_self(t, original_type, is_classmethod=True) - if is_classmethod or is_staticmethod: - assert isuper is not None + t = bind_self(t, mx.self_type, is_classmethod=True) + if isuper is not None: t = expand_type_by_instance(t, isuper) freeze_all_type_vars(t) return t.copy_modified(variables=list(tvars) + list(t.variables)) @@ -1432,14 +1433,7 @@ class B(A[str]): pass [ cast( CallableType, - add_class_tvars( - item, - isuper, - is_classmethod, - is_staticmethod, - original_type, - original_vars=original_vars, - ), + add_class_tvars(item, isuper, is_classmethod, mx, original_vars=original_vars), ) for item in t.items ] diff --git a/mypy/messages.py b/mypy/messages.py index 2e07d7f63498..d18e9917a095 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -2220,8 +2220,13 @@ def report_protocol_problems( exp = get_proper_type(exp) got = get_proper_type(got) setter_suffix = " setter type" if is_lvalue else "" - if not isinstance(exp, (CallableType, Overloaded)) or not isinstance( - got, (CallableType, Overloaded) + if ( + not isinstance(exp, (CallableType, Overloaded)) + or not isinstance(got, (CallableType, Overloaded)) + # If expected type is a type object, it means it is a nested class. + # Showing constructor signature in errors would be confusing in this case, + # since we don't check the signature, only subclassing of type objects. + or exp.is_type_obj() ): self.note( "{}: expected{} {}, got {}".format( diff --git a/mypy/plugin.py b/mypy/plugin.py index 39841d5b907a..de075866d613 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -119,14 +119,13 @@ class C: pass from __future__ import annotations from abc import abstractmethod -from typing import Any, Callable, NamedTuple, TypeVar +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar from mypy_extensions import mypyc_attr, trait from mypy.errorcodes import ErrorCode from mypy.lookup import lookup_fully_qualified from mypy.message_registry import ErrorMessage -from mypy.messages import MessageBuilder from mypy.nodes import ( ArgKind, CallExpr, @@ -138,7 +137,6 @@ class C: pass TypeInfo, ) from mypy.options import Options -from mypy.tvar_scope import TypeVarLikeScope from mypy.types import ( CallableType, FunctionLike, @@ -149,6 +147,10 @@ class C: pass UnboundType, ) +if TYPE_CHECKING: + from mypy.messages import MessageBuilder + from mypy.tvar_scope import TypeVarLikeScope + @trait class TypeAnalyzerPluginInterface: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 84fda7955d75..8d72e44d0eda 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -8,6 +8,7 @@ import mypy.applytype import mypy.constraints import mypy.typeops +from mypy.checker_state import checker_state from mypy.erasetype import erase_type from mypy.expandtype import ( expand_self_type, @@ -26,6 +27,7 @@ COVARIANT, INVARIANT, VARIANCE_NOT_READY, + Context, Decorator, FuncBase, OverloadedFuncDef, @@ -717,8 +719,7 @@ def visit_callable_type(self, left: CallableType) -> bool: elif isinstance(right, Instance): if right.type.is_protocol and "__call__" in right.type.protocol_members: # OK, a callable can implement a protocol with a `__call__` member. - # TODO: we should probably explicitly exclude self-types in this case. - call = find_member("__call__", right, left, is_operator=True) + call = find_member("__call__", right, right, is_operator=True) assert call is not None if self._is_subtype(left, call): if len(right.type.protocol_members) == 1: @@ -954,7 +955,7 @@ def visit_overloaded(self, left: Overloaded) -> bool: if isinstance(right, Instance): if right.type.is_protocol and "__call__" in right.type.protocol_members: # same as for CallableType - call = find_member("__call__", right, left, is_operator=True) + call = find_member("__call__", right, right, is_operator=True) assert call is not None if self._is_subtype(left, call): if len(right.type.protocol_members) == 1: @@ -1266,14 +1267,87 @@ def find_member( is_operator: bool = False, class_obj: bool = False, is_lvalue: bool = False, +) -> Type | None: + type_checker = checker_state.type_checker + if type_checker is None: + # Unfortunately, there are many scenarios where someone calls is_subtype() before + # type checking phase. In this case we fallback to old (incomplete) logic. + # TODO: reduce number of such cases (e.g. semanal_typeargs, post-semanal plugins). + return find_member_simple( + name, itype, subtype, is_operator=is_operator, class_obj=class_obj, is_lvalue=is_lvalue + ) + + # We don't use ATTR_DEFINED error code below (since missing attributes can cause various + # other error codes), instead we perform quick node lookup with all the fallbacks. + info = itype.type + sym = info.get(name) + node = sym.node if sym else None + if not node: + name_not_found = True + if ( + name not in ["__getattr__", "__setattr__", "__getattribute__"] + and not is_operator + and not class_obj + and itype.extra_attrs is None # skip ModuleType.__getattr__ + ): + for method_name in ("__getattribute__", "__getattr__"): + method = info.get_method(method_name) + if method and method.info.fullname != "builtins.object": + name_not_found = False + break + if name_not_found: + if info.fallback_to_any or class_obj and info.meta_fallback_to_any: + return AnyType(TypeOfAny.special_form) + if itype.extra_attrs and name in itype.extra_attrs.attrs: + return itype.extra_attrs.attrs[name] + return None + + from mypy.checkmember import ( + MemberContext, + analyze_class_attribute_access, + analyze_instance_member_access, + ) + + mx = MemberContext( + is_lvalue=is_lvalue, + is_super=False, + is_operator=is_operator, + original_type=itype, + self_type=subtype, + context=Context(), # all errors are filtered, but this is a required argument + chk=type_checker, + suppress_errors=True, + # This is needed to avoid infinite recursion in situations involving protocols like + # class P(Protocol[T]): + # def combine(self, other: P[S]) -> P[Tuple[T, S]]: ... + # Normally we call freshen_all_functions_type_vars() during attribute access, + # to avoid type variable id collisions, but for protocols this means we can't + # use the assumption stack, that will grow indefinitely. + # TODO: find a cleaner solution that doesn't involve massive perf impact. + preserve_type_var_ids=True, + ) + with type_checker.msg.filter_errors(filter_deprecated=True): + if class_obj: + fallback = itype.type.metaclass_type or mx.named_type("builtins.type") + return analyze_class_attribute_access(itype, name, mx, mcs_fallback=fallback) + else: + return analyze_instance_member_access(name, itype, mx, info) + + +def find_member_simple( + name: str, + itype: Instance, + subtype: Type, + *, + is_operator: bool = False, + class_obj: bool = False, + is_lvalue: bool = False, ) -> Type | None: """Find the type of member by 'name' in 'itype's TypeInfo. Find the member type after applying type arguments from 'itype', and binding 'self' to 'subtype'. Return None if member was not found. """ - # TODO: this code shares some logic with checkmember.analyze_member_access, - # consider refactoring. info = itype.type method = info.get_method(name) if method: diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 89693a6a7be0..0bcaf94f9e6a 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -3561,6 +3561,21 @@ def foo(x: T): reveal_type(C(0, x)) # N: Revealed type is "__main__.C[__main__.Int[T`-1]]" reveal_type(C("yes", x)) # N: Revealed type is "__main__.C[__main__.Str[T`-1]]" +[case testInstanceMethodBoundOnClass] +from typing import TypeVar, Generic + +T = TypeVar("T") +class B(Generic[T]): + def foo(self) -> T: ... +class C(B[T]): ... +class D(C[int]): ... + +reveal_type(B.foo) # N: Revealed type is "def [T] (self: __main__.B[T`1]) -> T`1" +reveal_type(B[int].foo) # N: Revealed type is "def (self: __main__.B[builtins.int]) -> builtins.int" +reveal_type(C.foo) # N: Revealed type is "def [T] (self: __main__.B[T`1]) -> T`1" +reveal_type(C[int].foo) # N: Revealed type is "def (self: __main__.B[builtins.int]) -> builtins.int" +reveal_type(D.foo) # N: Revealed type is "def (self: __main__.B[builtins.int]) -> builtins.int" + [case testDeterminismFromJoinOrderingInSolver] # Used to fail non-deterministically # https://github.com/python/mypy/issues/19121 diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index 7f11774fbfff..5e34d5223907 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -4460,3 +4460,48 @@ f2(a4) # E: Argument 1 to "f2" has incompatible type "A4"; expected "P2" \ # N: foo: expected "B1", got "str" \ # N: foo: expected setter type "C1", got "str" [builtins fixtures/property.pyi] + +[case testProtocolImplementationWithDescriptors] +from typing import Any, Protocol + +class Descr: + def __get__(self, inst: Any, owner: Any) -> int: ... + +class DescrBad: + def __get__(self, inst: Any, owner: Any) -> str: ... + +class Proto(Protocol): + x: int + +class C: + x = Descr() + +class CBad: + x = DescrBad() + +a: Proto = C() +b: Proto = CBad() # E: Incompatible types in assignment (expression has type "CBad", variable has type "Proto") \ + # N: Following member(s) of "CBad" have conflicts: \ + # N: x: expected "int", got "str" + +[case testProtocolCheckDefersNode] +from typing import Any, Callable, Protocol + +class Proto(Protocol): + def f(self) -> int: + ... + +def defer(f: Callable[[Any], int]) -> Callable[[Any], str]: + ... + +def bad() -> Proto: + return Impl() # E: Incompatible return value type (got "Impl", expected "Proto") \ + # N: Following member(s) of "Impl" have conflicts: \ + # N: Expected: \ + # N: def f(self) -> int \ + # N: Got: \ + # N: def f() -> str \ + +class Impl: + @defer + def f(self) -> int: ... diff --git a/test-data/unit/check-python312.test b/test-data/unit/check-python312.test index 70ab59eb28e4..315c13ab762b 100644 --- a/test-data/unit/check-python312.test +++ b/test-data/unit/check-python312.test @@ -246,6 +246,7 @@ class Invariant[T]: inv1: Invariant[float] = Invariant[int]([1]) # E: Incompatible types in assignment (expression has type "Invariant[int]", variable has type "Invariant[float]") inv2: Invariant[int] = Invariant[float]([1]) # E: Incompatible types in assignment (expression has type "Invariant[float]", variable has type "Invariant[int]") [builtins fixtures/tuple.pyi] +[typing fixtures/typing-full.pyi] [case testPEP695InferVarianceCalculateOnDemand] class Covariant[T]: @@ -1635,8 +1636,8 @@ class M[T: (int, str)](NamedTuple): c: M[int] d: M[str] e: M[bool] # E: Value of type variable "T" of "M" cannot be "bool" - [builtins fixtures/tuple.pyi] +[typing fixtures/typing-full.pyi] [case testPEP695GenericTypedDict] from typing import TypedDict diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index f9d7ce7fc975..4ac69321a250 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2780,7 +2780,7 @@ class TD(TypedDict): reveal_type(TD.__iter__) # N: Revealed type is "def (typing._TypedDict) -> typing.Iterator[builtins.str]" reveal_type(TD.__annotations__) # N: Revealed type is "typing.Mapping[builtins.str, builtins.object]" -reveal_type(TD.values) # N: Revealed type is "def (self: typing.Mapping[T`1, T_co`2]) -> typing.Iterable[T_co`2]" +reveal_type(TD.values) # N: Revealed type is "def (self: typing.Mapping[builtins.str, builtins.object]) -> typing.Iterable[builtins.object]" [builtins fixtures/dict-full.pyi] [typing fixtures/typing-typeddict.pyi]