diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 86e1dc06fc25..be89c2f09a80 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -1484,19 +1484,20 @@ def bind_self_fast(method: F, original_type: Type | None = None) -> F: items = [bind_self_fast(c, original_type) for c in method.items] return cast(F, Overloaded(items)) assert isinstance(method, CallableType) - if not method.arg_types: + func: CallableType = method + if not func.arg_types: # Invalid method, return something. - return cast(F, method) - if method.arg_kinds[0] in (ARG_STAR, ARG_STAR2): + return method + if func.arg_kinds[0] in (ARG_STAR, ARG_STAR2): # See typeops.py for details. - return cast(F, method) + return method original_type = get_proper_type(original_type) if isinstance(original_type, CallableType) and original_type.is_type_obj(): original_type = TypeType.make_normalized(original_type.ret_type) - res = method.copy_modified( - arg_types=method.arg_types[1:], - arg_kinds=method.arg_kinds[1:], - arg_names=method.arg_names[1:], + res = func.copy_modified( + arg_types=func.arg_types[1:], + arg_kinds=func.arg_kinds[1:], + arg_names=func.arg_names[1:], bound_args=[original_type], ) return cast(F, res) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 031f86e7dfff..d27105f48ed3 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -122,7 +122,7 @@ def freshen_function_type_vars(callee: F) -> F: """Substitute fresh type variables for generic function type variables.""" if isinstance(callee, CallableType): if not callee.is_generic(): - return cast(F, callee) + return callee tvs = [] tvmap: dict[TypeVarId, Type] = {} for v in callee.variables: diff --git a/mypy/join.py b/mypy/join.py index 65cc3bef66a4..a012a633dfa3 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -298,7 +298,9 @@ def visit_erased_type(self, t: ErasedType) -> ProperType: def visit_type_var(self, t: TypeVarType) -> ProperType: if isinstance(self.s, TypeVarType) and self.s.id == t.id: - return self.s + if self.s.upper_bound == t.upper_bound: + return self.s + return self.s.copy_modified(upper_bound=join_types(self.s.upper_bound, t.upper_bound)) else: return self.default(self.s) diff --git a/mypy/meet.py b/mypy/meet.py index add0785f5e71..7a44feabc10c 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -50,6 +50,7 @@ find_unpack_in_list, get_proper_type, get_proper_types, + has_type_vars, is_named_instance, split_with_prefix_and_suffix, ) @@ -149,6 +150,14 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: return make_simplified_union( [narrow_declared_type(declared, x) for x in narrowed.relevant_items()] ) + elif ( + isinstance(declared, TypeVarType) + and not has_type_vars(original_narrowed) + and is_subtype(original_narrowed, declared.upper_bound) + ): + # We put this branch early to get T(bound=Union[A, B]) instead of + # Union[T(bound=A), T(bound=B)] that will be confusing for users. + return declared.copy_modified(upper_bound=original_narrowed) elif not is_overlapping_types(declared, narrowed, prohibit_none_typevar_overlap=True): if state.strict_optional: return UninhabitedType() @@ -777,7 +786,9 @@ def visit_erased_type(self, t: ErasedType) -> ProperType: def visit_type_var(self, t: TypeVarType) -> ProperType: if isinstance(self.s, TypeVarType) and self.s.id == t.id: - return self.s + if self.s.upper_bound == t.upper_bound: + return self.s + return self.s.copy_modified(upper_bound=self.meet(self.s.upper_bound, t.upper_bound)) else: return self.default(self.s) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 8d72e44d0eda..15c8014c0f3f 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -632,7 +632,14 @@ def visit_instance(self, left: Instance) -> bool: def visit_type_var(self, left: TypeVarType) -> bool: right = self.right if isinstance(right, TypeVarType) and left.id == right.id: - return True + # Fast path for most common case. + if left.upper_bound == right.upper_bound: + return True + # Corner case for self-types in classes generic in type vars + # with value restrictions. + if left.id.is_self(): + return True + return self._is_subtype(left.upper_bound, right.upper_bound) if left.values and self._is_subtype(UnionType.make_union(left.values), right): return True return self._is_subtype(left.upper_bound, self.right) diff --git a/mypy/typeops.py b/mypy/typeops.py index 3715081ae173..da2796ff5dec 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -415,10 +415,10 @@ class B(A): pass ] return cast(F, Overloaded(items)) assert isinstance(method, CallableType) - func = method + func: CallableType = method if not func.arg_types: # Invalid method, return something. - return cast(F, func) + return method if func.arg_kinds[0] in (ARG_STAR, ARG_STAR2): # The signature is of the form 'def foo(*args, ...)'. # In this case we shouldn't drop the first arg, @@ -427,7 +427,7 @@ class B(A): pass # In the case of **kwargs we should probably emit an error, but # for now we simply skip it, to avoid crashes down the line. - return cast(F, func) + return method self_param_type = get_proper_type(func.arg_types[0]) variables: Sequence[TypeVarLikeType] diff --git a/mypy/types.py b/mypy/types.py index d2094cd15774..d83b320106ab 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -461,6 +461,11 @@ def __init__(self, type_guard: Type) -> None: def __repr__(self) -> str: return f"TypeGuard({self.type_guard})" + # This may hide some real bugs, but it is convenient for various "synthetic" + # visitors, similar to RequiredType and ReadOnlyType below. + def accept(self, visitor: TypeVisitor[T]) -> T: + return self.type_guard.accept(visitor) + class RequiredType(Type): """Required[T] or NotRequired[T]. Only usable at top-level of a TypedDict definition.""" diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index 288f281c0a94..b98f1989da51 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -2983,3 +2983,31 @@ class B(native.A): b: B = B.make() assert(B.count == 2) + +[case testTypeVarNarrowing] +from typing import TypeVar + +class B: + def __init__(self, x: int) -> None: + self.x = x +class C(B): + def __init__(self, x: int, y: str) -> None: + self.x = x + self.y = y + +T = TypeVar("T", bound=B) +def f(x: T) -> T: + if isinstance(x, C): + print("C", x.y) + return x + print("B", x.x) + return x + +[file driver.py] +from native import f, B, C + +f(B(1)) +f(C(1, "yes")) +[out] +B 1 +C yes diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 054ba0708ce3..9c95458361fd 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -6891,10 +6891,11 @@ reveal_type(i.x) # N: Revealed type is "builtins.int" [builtins fixtures/isinstancelist.pyi] [case testIsInstanceTypeTypeVar] -from typing import Type, TypeVar, Generic +from typing import Type, TypeVar, Generic, ClassVar class Base: ... -class Sub(Base): ... +class Sub(Base): + other: ClassVar[int] T = TypeVar('T', bound=Base) @@ -6902,13 +6903,9 @@ class C(Generic[T]): def meth(self, cls: Type[T]) -> None: if not issubclass(cls, Sub): return - reveal_type(cls) # N: Revealed type is "type[__main__.Sub]" - def other(self, cls: Type[T]) -> None: - if not issubclass(cls, Sub): - return - reveal_type(cls) # N: Revealed type is "type[__main__.Sub]" - -[builtins fixtures/isinstancelist.pyi] + reveal_type(cls) # N: Revealed type is "type[T`1]" + reveal_type(cls.other) # N: Revealed type is "builtins.int" +[builtins fixtures/isinstance.pyi] [case testIsInstanceTypeSubclass] from typing import Type, Optional @@ -7602,7 +7599,7 @@ class C1: class C2(Generic[TypeT]): def method(self, other: TypeT) -> int: if issubclass(other, Base): - reveal_type(other) # N: Revealed type is "type[__main__.Base]" + reveal_type(other) # N: Revealed type is "TypeT`1" return other.field return 0 diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index fe08d2cfc699..640fc10915d1 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -1821,19 +1821,23 @@ if issubclass(fm, Baz): from typing import TypeVar class A: pass -class B(A): pass +class B(A): + attr: int T = TypeVar('T', bound=A) def f(x: T) -> None: if isinstance(x, B): - reveal_type(x) # N: Revealed type is "__main__.B" + reveal_type(x) # N: Revealed type is "T`-1" + reveal_type(x.attr) # N: Revealed type is "builtins.int" else: reveal_type(x) # N: Revealed type is "T`-1" + x.attr # E: "T" has no attribute "attr" reveal_type(x) # N: Revealed type is "T`-1" + x.attr # E: "T" has no attribute "attr" [builtins fixtures/isinstance.pyi] -[case testIsinstanceAndNegativeNarrowTypeVariableWithUnionBound] +[case testIsinstanceAndNegativeNarrowTypeVariableWithUnionBound1] from typing import Union, TypeVar class A: @@ -1845,9 +1849,11 @@ T = TypeVar("T", bound=Union[A, B]) def f(x: T) -> T: if isinstance(x, A): - reveal_type(x) # N: Revealed type is "__main__.A" + reveal_type(x) # N: Revealed type is "T`-1" x.a - x.b # E: "A" has no attribute "b" + x.b # E: "T" has no attribute "b" + if bool(): + return x else: reveal_type(x) # N: Revealed type is "T`-1" x.a # E: "T" has no attribute "a" @@ -1857,6 +1863,24 @@ def f(x: T) -> T: return x [builtins fixtures/isinstance.pyi] +[case testIsinstanceAndNegativeNarrowTypeVariableWithUnionBound2] +from typing import Union, TypeVar + +class A: + a: int +class B: + b: int + +T = TypeVar("T", bound=Union[A, B]) + +def f(x: T) -> T: + if isinstance(x, A): + return x + x.a # E: "T" has no attribute "a" + x.b # OK + return x +[builtins fixtures/isinstance.pyi] + [case testIsinstanceAndTypeType] from typing import Type def f(x: Type[int]) -> None: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 4afed0e3ec86..36b2ced075d2 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2424,3 +2424,42 @@ def f() -> None: assert isinstance(x, int) reveal_type(x) # N: Revealed type is "builtins.int" [builtins fixtures/isinstance.pyi] + +[case testNarrowTypeVarBoundType] +from typing import Type, TypeVar + +class A: ... +class B(A): + other: int + +T = TypeVar("T", bound=A) +def test(cls: Type[T]) -> T: + if issubclass(cls, B): + reveal_type(cls) # N: Revealed type is "type[T`-1]" + reveal_type(cls().other) # N: Revealed type is "builtins.int" + return cls() + return cls() +[builtins fixtures/isinstance.pyi] + +[case testNarrowTypeVarBoundUnion] +from typing import TypeVar + +class A: + x: int +class B: + x: str + +T = TypeVar("T") +def test(x: T) -> T: + if not isinstance(x, (A, B)): + return x + reveal_type(x) # N: Revealed type is "T`-1" + reveal_type(x.x) # N: Revealed type is "Union[builtins.int, builtins.str]" + if isinstance(x, A): + reveal_type(x) # N: Revealed type is "T`-1" + reveal_type(x.x) # N: Revealed type is "builtins.int" + return x + reveal_type(x) # N: Revealed type is "T`-1" + reveal_type(x.x) # N: Revealed type is "builtins.str" + return x +[builtins fixtures/isinstance.pyi] diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test index 0b512962b8d1..c43eead67876 100644 --- a/test-data/unit/check-typeguard.test +++ b/test-data/unit/check-typeguard.test @@ -778,6 +778,23 @@ def handle(model: Model) -> int: return 0 [builtins fixtures/tuple.pyi] +[case testTypeGuardRestrictTypeVarUnion] +from typing import Union, TypeVar +from typing_extensions import TypeGuard + +class A: + x: int +class B: + x: str + +def is_b(x: object) -> TypeGuard[B]: ... + +T = TypeVar("T") +def test(x: T) -> T: + if isinstance(x, A) or is_b(x): + reveal_type(x.x) # N: Revealed type is "Union[builtins.int, builtins.str]" + return x +[builtins fixtures/isinstance.pyi] [case testOverloadedTypeGuardType] from __future__ import annotations