From 3a3cdf84589b1e795755a8256d67f223a762c80e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 19 Apr 2025 17:07:10 +0100 Subject: [PATCH 1/4] Use checkmember.py to check protocol subtyping --- mypy/checker.py | 4 +- mypy/checkmember.py | 55 +++++++++---------- mypy/expandtype.py | 3 +- mypy/messages.py | 9 +++- mypy/plugin.py | 8 +-- mypy/state.py | 20 +++++-- mypy/subtypes.py | 83 +++++++++++++++++++++++++++-- mypy/types.py | 3 +- test-data/unit/check-python312.test | 3 +- test-data/unit/check-typeddict.test | 2 +- 10 files changed, 139 insertions(+), 51 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 7d0b41c516e1..19bf24327dc8 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -455,7 +455,7 @@ def check_first_pass(self) -> None: Deferred functions will be processed by check_second_pass(). """ self.recurse_into_functions = True - with state.strict_optional_set(self.options.strict_optional): + with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self): self.errors.set_file( self.path, self.tree.fullname, scope=self.tscope, options=self.options ) @@ -496,7 +496,7 @@ def check_second_pass( This goes through deferred nodes, returning True if there were any. """ self.recurse_into_functions = True - with state.strict_optional_set(self.options.strict_optional): + with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self): if not todo and not self.deferred_nodes: return False self.errors.set_file( diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 1a76372d4731..b2e443c82e80 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -96,6 +96,7 @@ def __init__( is_self: bool = False, rvalue: Expression | None = None, suppress_errors: bool = False, + preserve_type_var_ids: bool = False, ) -> None: self.is_lvalue = is_lvalue self.is_super = is_super @@ -112,6 +113,10 @@ def __init__( assert is_lvalue self.rvalue = rvalue self.suppress_errors = suppress_errors + # This attribute is only used to preserve old protocol member access logic. + # It is needed to avoid infinite recursion in cases involving self-referential + # generic methods, see find_member() for details. Do not use for other purposes! + self.preserve_type_var_ids = preserve_type_var_ids def named_type(self, name: str) -> Instance: return self.chk.named_type(name) @@ -142,6 +147,7 @@ def copy_modified( no_deferral=self.no_deferral, rvalue=self.rvalue, suppress_errors=self.suppress_errors, + preserve_type_var_ids=self.preserve_type_var_ids, ) if self_type is not None: mx.self_type = self_type @@ -231,8 +237,6 @@ def analyze_member_access( def _analyze_member_access( name: str, typ: Type, mx: MemberContext, override_info: TypeInfo | None = None ) -> Type: - # TODO: This and following functions share some logic with subtypes.find_member; - # consider refactoring. typ = get_proper_type(typ) if isinstance(typ, Instance): return analyze_instance_member_access(name, typ, mx, override_info) @@ -355,7 +359,8 @@ def analyze_instance_member_access( return AnyType(TypeOfAny.special_form) assert isinstance(method.type, Overloaded) signature = method.type - signature = freshen_all_functions_type_vars(signature) + if not mx.preserve_type_var_ids: + signature = freshen_all_functions_type_vars(signature) if not method.is_static: signature = check_self_arg( signature, mx.self_type, method.is_class, mx.context, name, mx.msg @@ -928,7 +933,8 @@ def analyze_var( def expand_without_binding( typ: Type, var: Var, itype: Instance, original_itype: Instance, mx: MemberContext ) -> Type: - typ = freshen_all_functions_type_vars(typ) + if not mx.preserve_type_var_ids: + typ = freshen_all_functions_type_vars(typ) typ = expand_self_type_if_needed(typ, mx, var, original_itype) expanded = expand_type_by_instance(typ, itype) freeze_all_type_vars(expanded) @@ -938,7 +944,8 @@ def expand_without_binding( def expand_and_bind_callable( functype: FunctionLike, var: Var, itype: Instance, name: str, mx: MemberContext ) -> Type: - functype = freshen_all_functions_type_vars(functype) + if not mx.preserve_type_var_ids: + functype = freshen_all_functions_type_vars(functype) typ = get_proper_type(expand_self_type(var, functype, mx.original_type)) assert isinstance(typ, FunctionLike) typ = check_self_arg(typ, mx.self_type, var.is_classmethod, mx.context, name, mx.msg) @@ -1033,10 +1040,12 @@ def f(self: S) -> T: ... return functype else: selfarg = get_proper_type(item.arg_types[0]) - # This level of erasure matches the one in checker.check_func_def(), - # better keep these two checks consistent. - if subtypes.is_subtype( + # This matches similar special-casing in bind_self(), see more details there. + self_callable = name == "__call__" and isinstance(selfarg, CallableType) + if self_callable or subtypes.is_subtype( dispatched_arg_type, + # This level of erasure matches the one in checker.check_func_def(), + # better keep these two checks consistent. erase_typevars(erase_to_bound(selfarg)), # This is to work around the fact that erased ParamSpec and TypeVarTuple # callables are not always compatible with non-erased ones both ways. @@ -1197,15 +1206,10 @@ def analyze_class_attribute_access( is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or ( isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_class ) - is_staticmethod = (is_decorated and cast(Decorator, node.node).func.is_static) or ( - isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_static - ) t = get_proper_type(t) if isinstance(t, FunctionLike) and is_classmethod: t = check_self_arg(t, mx.self_type, False, mx.context, name, mx.msg) - result = add_class_tvars( - t, isuper, is_classmethod, is_staticmethod, mx.self_type, original_vars=original_vars - ) + result = add_class_tvars(t, isuper, is_classmethod, mx, original_vars=original_vars) # __set__ is not called on class objects. if not mx.is_lvalue: result = analyze_descriptor_access(result, mx) @@ -1337,8 +1341,7 @@ def add_class_tvars( t: ProperType, isuper: Instance | None, is_classmethod: bool, - is_staticmethod: bool, - original_type: Type, + mx: MemberContext, original_vars: Sequence[TypeVarLikeType] | None = None, ) -> Type: """Instantiate type variables during analyze_class_attribute_access, @@ -1356,9 +1359,6 @@ class B(A[str]): pass isuper: Current instance mapped to the superclass where method was defined, this is usually done by map_instance_to_supertype() is_classmethod: True if this method is decorated with @classmethod - is_staticmethod: True if this method is decorated with @staticmethod - original_type: The value of the type B in the expression B.foo() or the corresponding - component in case of a union (this is used to bind the self-types) original_vars: Type variables of the class callable on which the method was accessed Returns: Expanded method type with added type variables (when needed). @@ -1379,11 +1379,11 @@ class B(A[str]): pass # (i.e. appear in the return type of the class object on which the method was accessed). if isinstance(t, CallableType): tvars = original_vars if original_vars is not None else [] - t = freshen_all_functions_type_vars(t) + if not mx.preserve_type_var_ids: + t = freshen_all_functions_type_vars(t) if is_classmethod: - t = bind_self(t, original_type, is_classmethod=True) - if is_classmethod or is_staticmethod: - assert isuper is not None + t = bind_self(t, mx.self_type, is_classmethod=True) + if isuper is not None: t = expand_type_by_instance(t, isuper) freeze_all_type_vars(t) return t.copy_modified(variables=list(tvars) + list(t.variables)) @@ -1392,14 +1392,7 @@ class B(A[str]): pass [ cast( CallableType, - add_class_tvars( - item, - isuper, - is_classmethod, - is_staticmethod, - original_type, - original_vars=original_vars, - ), + add_class_tvars(item, isuper, is_classmethod, mx, original_vars=original_vars), ) for item in t.items ] diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 031f86e7dfff..f17d3ecfcd83 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -4,7 +4,6 @@ from typing import Final, TypeVar, cast, overload from mypy.nodes import ARG_STAR, FakeInfo, Var -from mypy.state import state from mypy.types import ( ANY_STRATEGY, AnyType, @@ -544,6 +543,8 @@ def remove_trivial(types: Iterable[Type]) -> list[Type]: * Remove everything else if there is an `object` * Remove strict duplicate types """ + from mypy.state import state + removed_none = False new_types = [] all_types = set() diff --git a/mypy/messages.py b/mypy/messages.py index 2e07d7f63498..d18e9917a095 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -2220,8 +2220,13 @@ def report_protocol_problems( exp = get_proper_type(exp) got = get_proper_type(got) setter_suffix = " setter type" if is_lvalue else "" - if not isinstance(exp, (CallableType, Overloaded)) or not isinstance( - got, (CallableType, Overloaded) + if ( + not isinstance(exp, (CallableType, Overloaded)) + or not isinstance(got, (CallableType, Overloaded)) + # If expected type is a type object, it means it is a nested class. + # Showing constructor signature in errors would be confusing in this case, + # since we don't check the signature, only subclassing of type objects. + or exp.is_type_obj() ): self.note( "{}: expected{} {}, got {}".format( diff --git a/mypy/plugin.py b/mypy/plugin.py index 39841d5b907a..de075866d613 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -119,14 +119,13 @@ class C: pass from __future__ import annotations from abc import abstractmethod -from typing import Any, Callable, NamedTuple, TypeVar +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar from mypy_extensions import mypyc_attr, trait from mypy.errorcodes import ErrorCode from mypy.lookup import lookup_fully_qualified from mypy.message_registry import ErrorMessage -from mypy.messages import MessageBuilder from mypy.nodes import ( ArgKind, CallExpr, @@ -138,7 +137,6 @@ class C: pass TypeInfo, ) from mypy.options import Options -from mypy.tvar_scope import TypeVarLikeScope from mypy.types import ( CallableType, FunctionLike, @@ -149,6 +147,10 @@ class C: pass UnboundType, ) +if TYPE_CHECKING: + from mypy.messages import MessageBuilder + from mypy.tvar_scope import TypeVarLikeScope + @trait class TypeAnalyzerPluginInterface: diff --git a/mypy/state.py b/mypy/state.py index a3055bf6b208..41b8b75be127 100644 --- a/mypy/state.py +++ b/mypy/state.py @@ -4,16 +4,19 @@ from contextlib import contextmanager from typing import Final +from mypy.checker_shared import TypeCheckerSharedApi + # These are global mutable state. Don't add anything here unless there's a very # good reason. -class StrictOptionalState: +class SubtypeState: # Wrap this in a class since it's faster that using a module-level attribute. - def __init__(self, strict_optional: bool) -> None: - # Value varies by file being processed + def __init__(self, strict_optional: bool, type_checker: TypeCheckerSharedApi | None) -> None: + # Values vary by file being processed self.strict_optional = strict_optional + self.type_checker = type_checker @contextmanager def strict_optional_set(self, value: bool) -> Iterator[None]: @@ -24,6 +27,15 @@ def strict_optional_set(self, value: bool) -> Iterator[None]: finally: self.strict_optional = saved + @contextmanager + def type_checker_set(self, value: TypeCheckerSharedApi) -> Iterator[None]: + saved = self.type_checker + self.type_checker = value + try: + yield + finally: + self.type_checker = saved + -state: Final = StrictOptionalState(strict_optional=True) +state: Final = SubtypeState(strict_optional=True, type_checker=None) find_occurrences: tuple[str, str] | None = None diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 71b8b0ba59f5..226c39bb2933 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -26,6 +26,7 @@ COVARIANT, INVARIANT, VARIANCE_NOT_READY, + Context, Decorator, FuncBase, OverloadedFuncDef, @@ -717,8 +718,7 @@ def visit_callable_type(self, left: CallableType) -> bool: elif isinstance(right, Instance): if right.type.is_protocol and "__call__" in right.type.protocol_members: # OK, a callable can implement a protocol with a `__call__` member. - # TODO: we should probably explicitly exclude self-types in this case. - call = find_member("__call__", right, left, is_operator=True) + call = find_member("__call__", right, right, is_operator=True) assert call is not None if self._is_subtype(left, call): if len(right.type.protocol_members) == 1: @@ -954,7 +954,7 @@ def visit_overloaded(self, left: Overloaded) -> bool: if isinstance(right, Instance): if right.type.is_protocol and "__call__" in right.type.protocol_members: # same as for CallableType - call = find_member("__call__", right, left, is_operator=True) + call = find_member("__call__", right, right, is_operator=True) assert call is not None if self._is_subtype(left, call): if len(right.type.protocol_members) == 1: @@ -1261,14 +1261,87 @@ def find_member( is_operator: bool = False, class_obj: bool = False, is_lvalue: bool = False, +) -> Type | None: + type_checker = state.type_checker + if type_checker is None: + # Unfortunately, there are many scenarios where someone calls is_subtype() before + # type checking phase. In this case we fallback to old (incomplete) logic. + # TODO: reduce number of such cases (e.g. semanal_typeargs, post-semanal plugins). + return find_member_simple( + name, itype, subtype, is_operator=is_operator, class_obj=class_obj, is_lvalue=is_lvalue + ) + + # We don't use ATTR_DEFINED error code below (since missing attributes can cause various + # other error codes), instead we perform quick node lookup with all the fallbacks. + info = itype.type + sym = info.get(name) + node = sym.node if sym else None + if not node: + name_not_found = True + if ( + name not in ["__getattr__", "__setattr__", "__getattribute__"] + and not is_operator + and not class_obj + and itype.extra_attrs is None # skip ModuleType.__getattr__ + ): + for method_name in ("__getattribute__", "__getattr__"): + method = info.get_method(method_name) + if method and method.info.fullname != "builtins.object": + name_not_found = False + break + if name_not_found: + if info.fallback_to_any or class_obj and info.meta_fallback_to_any: + return AnyType(TypeOfAny.special_form) + if itype.extra_attrs and name in itype.extra_attrs.attrs: + return itype.extra_attrs.attrs[name] + return None + + from mypy.checkmember import ( + MemberContext, + analyze_class_attribute_access, + analyze_instance_member_access, + ) + + mx = MemberContext( + is_lvalue=is_lvalue, + is_super=False, + is_operator=is_operator, + original_type=itype, + self_type=subtype, + context=Context(), # all errors are filtered, but this is a required argument + chk=type_checker, + suppress_errors=True, + # This is needed to avoid infinite recursion in situations involving protocols like + # class P(Protocol[T]): + # def combine(self, other: P[S]) -> P[Tuple[T, S]]: ... + # Normally we call freshen_all_functions_type_vars() during attribute access, + # to avoid type variable id collisions, but for protocols this means we can't + # use the assumption stack, that will grow indefinitely. + # TODO: find a cleaner solution that doesn't involve massive perf impact. + preserve_type_var_ids=True, + ) + with type_checker.msg.filter_errors(filter_deprecated=True): + if class_obj: + fallback = itype.type.metaclass_type or mx.named_type("builtins.type") + return analyze_class_attribute_access(itype, name, mx, mcs_fallback=fallback) + else: + return analyze_instance_member_access(name, itype, mx, info) + + +def find_member_simple( + name: str, + itype: Instance, + subtype: Type, + *, + is_operator: bool = False, + class_obj: bool = False, + is_lvalue: bool = False, ) -> Type | None: """Find the type of member by 'name' in 'itype's TypeInfo. Find the member type after applying type arguments from 'itype', and binding 'self' to 'subtype'. Return None if member was not found. """ - # TODO: this code shares some logic with checkmember.analyze_member_access, - # consider refactoring. info = itype.type method = info.get_method(name) if method: diff --git a/mypy/types.py b/mypy/types.py index 41a958ae93cc..a922f64a47a8 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -32,7 +32,6 @@ SymbolNode, ) from mypy.options import Options -from mypy.state import state from mypy.util import IdMapper T = TypeVar("T") @@ -2979,6 +2978,8 @@ def accept(self, visitor: TypeVisitor[T]) -> T: def relevant_items(self) -> list[Type]: """Removes NoneTypes from Unions when strict Optional checking is off.""" + from mypy.state import state + if state.strict_optional: return self.items else: diff --git a/test-data/unit/check-python312.test b/test-data/unit/check-python312.test index 2f3d5e08dab3..54864b24ea40 100644 --- a/test-data/unit/check-python312.test +++ b/test-data/unit/check-python312.test @@ -246,6 +246,7 @@ class Invariant[T]: inv1: Invariant[float] = Invariant[int]([1]) # E: Incompatible types in assignment (expression has type "Invariant[int]", variable has type "Invariant[float]") inv2: Invariant[int] = Invariant[float]([1]) # E: Incompatible types in assignment (expression has type "Invariant[float]", variable has type "Invariant[int]") [builtins fixtures/tuple.pyi] +[typing fixtures/typing-full.pyi] [case testPEP695InferVarianceCalculateOnDemand] class Covariant[T]: @@ -1635,8 +1636,8 @@ class M[T: (int, str)](NamedTuple): c: M[int] d: M[str] e: M[bool] # E: Value of type variable "T" of "M" cannot be "bool" - [builtins fixtures/tuple.pyi] +[typing fixtures/typing-full.pyi] [case testPEP695GenericTypedDict] from typing import TypedDict diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 47c8a71ba0e3..5d6706c35308 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2780,7 +2780,7 @@ class TD(TypedDict): reveal_type(TD.__iter__) # N: Revealed type is "def (typing._TypedDict) -> typing.Iterator[builtins.str]" reveal_type(TD.__annotations__) # N: Revealed type is "typing.Mapping[builtins.str, builtins.object]" -reveal_type(TD.values) # N: Revealed type is "def (self: typing.Mapping[T`1, T_co`2]) -> typing.Iterable[T_co`2]" +reveal_type(TD.values) # N: Revealed type is "def (self: typing.Mapping[builtins.str, builtins.object]) -> typing.Iterable[builtins.object]" [builtins fixtures/dict-full.pyi] [typing fixtures/typing-typeddict.pyi] From 9c3d3032e91913a2dcd91af494e16b5acdf20d7d Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Sat, 19 Apr 2025 13:33:19 -0700 Subject: [PATCH 2/4] show mypy primer speed regression --- .github/workflows/mypy_primer.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/mypy_primer.yml b/.github/workflows/mypy_primer.yml index ee868484751e..532e77a0cacb 100644 --- a/.github/workflows/mypy_primer.yml +++ b/.github/workflows/mypy_primer.yml @@ -67,6 +67,7 @@ jobs: --debug \ --additional-flags="--debug-serialize" \ --output concise \ + --show-speed-regression \ | tee diff_${{ matrix.shard-index }}.txt ) || [ $? -eq 1 ] - if: ${{ matrix.shard-index == 0 }} From 5b3190a1670ee5ddcc0e9f0d4d92c251d3243a63 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 11 May 2025 00:09:19 +0100 Subject: [PATCH 3/4] Put type checker global state in separate module --- mypy/checker.py | 5 +++-- mypy/checker_state.py | 30 ++++++++++++++++++++++++++++++ mypy/expandtype.py | 3 +-- mypy/state.py | 20 ++++---------------- mypy/subtypes.py | 3 ++- mypy/types.py | 3 +-- 6 files changed, 41 insertions(+), 23 deletions(-) create mode 100644 mypy/checker_state.py diff --git a/mypy/checker.py b/mypy/checker.py index 67ff66c7994f..76c86262b365 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -13,6 +13,7 @@ from mypy import errorcodes as codes, join, message_registry, nodes, operators from mypy.binder import ConditionalTypeBinder, Frame, get_declaration from mypy.checker_shared import CheckerScope, TypeCheckerSharedApi, TypeRange +from mypy.checker_state import checker_state from mypy.checkmember import ( MemberContext, analyze_class_attribute_access, @@ -455,7 +456,7 @@ def check_first_pass(self) -> None: Deferred functions will be processed by check_second_pass(). """ self.recurse_into_functions = True - with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self): + with state.strict_optional_set(self.options.strict_optional), checker_state.set(self): self.errors.set_file( self.path, self.tree.fullname, scope=self.tscope, options=self.options ) @@ -496,7 +497,7 @@ def check_second_pass( This goes through deferred nodes, returning True if there were any. """ self.recurse_into_functions = True - with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self): + with state.strict_optional_set(self.options.strict_optional), checker_state.set(self): if not todo and not self.deferred_nodes: return False self.errors.set_file( diff --git a/mypy/checker_state.py b/mypy/checker_state.py new file mode 100644 index 000000000000..9b988ad18ba4 --- /dev/null +++ b/mypy/checker_state.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Final + +from mypy.checker_shared import TypeCheckerSharedApi + +# This is global mutable state. Don't add anything here unless there's a very +# good reason. + + +class TypeCheckerState: + # Wrap this in a class since it's faster that using a module-level attribute. + + def __init__(self, type_checker: TypeCheckerSharedApi | None) -> None: + # Value varies by file being processed + self.type_checker = type_checker + + @contextmanager + def set(self, value: TypeCheckerSharedApi) -> Iterator[None]: + saved = self.type_checker + self.type_checker = value + try: + yield + finally: + self.type_checker = saved + + +checker_state: Final = TypeCheckerState(type_checker=None) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index f17d3ecfcd83..031f86e7dfff 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -4,6 +4,7 @@ from typing import Final, TypeVar, cast, overload from mypy.nodes import ARG_STAR, FakeInfo, Var +from mypy.state import state from mypy.types import ( ANY_STRATEGY, AnyType, @@ -543,8 +544,6 @@ def remove_trivial(types: Iterable[Type]) -> list[Type]: * Remove everything else if there is an `object` * Remove strict duplicate types """ - from mypy.state import state - removed_none = False new_types = [] all_types = set() diff --git a/mypy/state.py b/mypy/state.py index 41b8b75be127..a3055bf6b208 100644 --- a/mypy/state.py +++ b/mypy/state.py @@ -4,19 +4,16 @@ from contextlib import contextmanager from typing import Final -from mypy.checker_shared import TypeCheckerSharedApi - # These are global mutable state. Don't add anything here unless there's a very # good reason. -class SubtypeState: +class StrictOptionalState: # Wrap this in a class since it's faster that using a module-level attribute. - def __init__(self, strict_optional: bool, type_checker: TypeCheckerSharedApi | None) -> None: - # Values vary by file being processed + def __init__(self, strict_optional: bool) -> None: + # Value varies by file being processed self.strict_optional = strict_optional - self.type_checker = type_checker @contextmanager def strict_optional_set(self, value: bool) -> Iterator[None]: @@ -27,15 +24,6 @@ def strict_optional_set(self, value: bool) -> Iterator[None]: finally: self.strict_optional = saved - @contextmanager - def type_checker_set(self, value: TypeCheckerSharedApi) -> Iterator[None]: - saved = self.type_checker - self.type_checker = value - try: - yield - finally: - self.type_checker = saved - -state: Final = SubtypeState(strict_optional=True, type_checker=None) +state: Final = StrictOptionalState(strict_optional=True) find_occurrences: tuple[str, str] | None = None diff --git a/mypy/subtypes.py b/mypy/subtypes.py index d90e86ec1558..8d72e44d0eda 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -8,6 +8,7 @@ import mypy.applytype import mypy.constraints import mypy.typeops +from mypy.checker_state import checker_state from mypy.erasetype import erase_type from mypy.expandtype import ( expand_self_type, @@ -1267,7 +1268,7 @@ def find_member( class_obj: bool = False, is_lvalue: bool = False, ) -> Type | None: - type_checker = state.type_checker + type_checker = checker_state.type_checker if type_checker is None: # Unfortunately, there are many scenarios where someone calls is_subtype() before # type checking phase. In this case we fallback to old (incomplete) logic. diff --git a/mypy/types.py b/mypy/types.py index a922f64a47a8..41a958ae93cc 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -32,6 +32,7 @@ SymbolNode, ) from mypy.options import Options +from mypy.state import state from mypy.util import IdMapper T = TypeVar("T") @@ -2978,8 +2979,6 @@ def accept(self, visitor: TypeVisitor[T]) -> T: def relevant_items(self) -> list[Type]: """Removes NoneTypes from Unions when strict Optional checking is off.""" - from mypy.state import state - if state.strict_optional: return self.items else: From 689c119128defaa688779e0a999752b6d1b7c032 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 19 May 2025 10:54:26 +0100 Subject: [PATCH 4/4] Add dedicated tests for key resolved issues --- test-data/unit/check-generics.test | 15 ++++++++++ test-data/unit/check-protocols.test | 45 +++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 767b55efcac2..b801bba14069 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -3563,3 +3563,18 @@ def foo(x: T): reveal_type(C) # N: Revealed type is "Overload(def [T, S] (x: builtins.int, y: S`-1) -> __main__.C[__main__.Int[S`-1]], def [T, S] (x: builtins.str, y: S`-1) -> __main__.C[__main__.Str[S`-1]])" reveal_type(C(0, x)) # N: Revealed type is "__main__.C[__main__.Int[T`-1]]" reveal_type(C("yes", x)) # N: Revealed type is "__main__.C[__main__.Str[T`-1]]" + +[case testInstanceMethodBoundOnClass] +from typing import TypeVar, Generic + +T = TypeVar("T") +class B(Generic[T]): + def foo(self) -> T: ... +class C(B[T]): ... +class D(C[int]): ... + +reveal_type(B.foo) # N: Revealed type is "def [T] (self: __main__.B[T`1]) -> T`1" +reveal_type(B[int].foo) # N: Revealed type is "def (self: __main__.B[builtins.int]) -> builtins.int" +reveal_type(C.foo) # N: Revealed type is "def [T] (self: __main__.B[T`1]) -> T`1" +reveal_type(C[int].foo) # N: Revealed type is "def (self: __main__.B[builtins.int]) -> builtins.int" +reveal_type(D.foo) # N: Revealed type is "def (self: __main__.B[builtins.int]) -> builtins.int" diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index 34e3f3e88080..0cebb20c2cf3 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -4460,3 +4460,48 @@ f2(a4) # E: Argument 1 to "f2" has incompatible type "A4"; expected "P2" \ # N: foo: expected "B1", got "str" \ # N: foo: expected setter type "C1", got "str" [builtins fixtures/property.pyi] + +[case testProtocolImplementationWithDescriptors] +from typing import Any, Protocol + +class Descr: + def __get__(self, inst: Any, owner: Any) -> int: ... + +class DescrBad: + def __get__(self, inst: Any, owner: Any) -> str: ... + +class Proto(Protocol): + x: int + +class C: + x = Descr() + +class CBad: + x = DescrBad() + +a: Proto = C() +b: Proto = CBad() # E: Incompatible types in assignment (expression has type "CBad", variable has type "Proto") \ + # N: Following member(s) of "CBad" have conflicts: \ + # N: x: expected "int", got "str" + +[case testProtocolCheckDefersNode] +from typing import Any, Callable, Protocol + +class Proto(Protocol): + def f(self) -> int: + ... + +def defer(f: Callable[[Any], int]) -> Callable[[Any], str]: + ... + +def bad() -> Proto: + return Impl() # E: Incompatible return value type (got "Impl", expected "Proto") \ + # N: Following member(s) of "Impl" have conflicts: \ + # N: Expected: \ + # N: def f(self) -> int \ + # N: Got: \ + # N: def f() -> str \ + +class Impl: + @defer + def f(self) -> int: ...