diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index aa6d8e63f5f7..0753ee80c113 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -154,6 +154,7 @@ is_optional, remove_optional, ) +from mypy.typestate import TypeState from mypy.typevars import fill_typevars from mypy.util import split_module_names from mypy.visitor import ExpressionVisitor @@ -1429,6 +1430,22 @@ def infer_arg_types_in_empty_context(self, args: List[Expression]) -> List[Type] res.append(arg_type) return res + @contextmanager + def allow_unions(self, type_context: Type) -> Iterator[None]: + # This is a hack to better support inference for recursive types. + # When the outer context for a function call is known to be recursive, + # we solve type constraints inferred from arguments using unions instead + # of joins. This is a bit arbitrary, but in practice it works for most + # cases. A cleaner alternative would be to switch to single bin type + # inference, but this is a lot of work. + old = TypeState.infer_unions + if has_recursive_types(type_context): + TypeState.infer_unions = True + try: + yield + finally: + TypeState.infer_unions = old + def infer_arg_types_in_context( self, callee: CallableType, @@ -1448,7 +1465,8 @@ def infer_arg_types_in_context( for i, actuals in enumerate(formal_to_actual): for ai in actuals: if not arg_kinds[ai].is_star(): - res[ai] = self.accept(args[ai], callee.arg_types[i]) + with self.allow_unions(callee.arg_types[i]): + res[ai] = self.accept(args[ai], callee.arg_types[i]) # Fill in the rest of the argument types. for i, t in enumerate(res): @@ -1568,17 +1586,6 @@ def infer_function_type_arguments( else: pass1_args.append(arg) - # This is a hack to better support inference for recursive types. - # When the outer context for a function call is known to be recursive, - # we solve type constraints inferred from arguments using unions instead - # of joins. This is a bit arbitrary, but in practice it works for most - # cases. A cleaner alternative would be to switch to single bin type - # inference, but this is a lot of work. - ctx = self.type_context[-1] - if ctx and has_recursive_types(ctx): - infer_unions = True - else: - infer_unions = False inferred_args = infer_function_type_arguments( callee_type, pass1_args, @@ -1586,7 +1593,6 @@ def infer_function_type_arguments( formal_to_actual, context=self.argument_infer_context(), strict=self.chk.in_checked_function(), - infer_unions=infer_unions, ) if 2 in arg_pass_nums: diff --git a/mypy/constraints.py b/mypy/constraints.py index 0ca6a3e085f0..b4c3cf6f28c9 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -42,6 +42,8 @@ UnpackType, callable_with_ellipsis, get_proper_type, + has_recursive_types, + has_type_vars, is_named_instance, is_union_with_any, ) @@ -141,14 +143,19 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> List[Cons The constraints are represented as Constraint objects. """ if any( - get_proper_type(template) == get_proper_type(t) for t in reversed(TypeState._inferring) + get_proper_type(template) == get_proper_type(t) + and get_proper_type(actual) == get_proper_type(a) + for (t, a) in reversed(TypeState.inferring) ): return [] - if isinstance(template, TypeAliasType) and template.is_recursive: + if has_recursive_types(template): # This case requires special care because it may cause infinite recursion. - TypeState._inferring.append(template) + if not has_type_vars(template): + # Return early on an empty branch. + return [] + TypeState.inferring.append((template, actual)) res = _infer_constraints(template, actual, direction) - TypeState._inferring.pop() + TypeState.inferring.pop() return res return _infer_constraints(template, actual, direction) @@ -216,13 +223,18 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> List[Con # When the template is a union, we are okay with leaving some # type variables indeterminate. This helps with some special # cases, though this isn't very principled. - return any_constraints( + result = any_constraints( [ infer_constraints_if_possible(t_item, actual, direction) for t_item in template.items ], eager=False, ) + if result: + return result + elif has_recursive_types(template) and not has_recursive_types(actual): + return handle_recursive_union(template, actual, direction) + return [] # Remaining cases are handled by ConstraintBuilderVisitor. return template.accept(ConstraintBuilderVisitor(actual, direction)) @@ -279,6 +291,19 @@ def merge_with_any(constraint: Constraint) -> Constraint: ) +def handle_recursive_union(template: UnionType, actual: Type, direction: int) -> List[Constraint]: + # This is a hack to special-case things like Union[T, Inst[T]] in recursive types. Although + # it is quite arbitrary, it is a relatively common pattern, so we should handle it well. + # This function may be called when inferring against such union resulted in different + # constraints for each item. Normally we give up in such case, but here we instead split + # the union in two parts, and try inferring sequentially. + non_type_var_items = [t for t in template.items if not isinstance(t, TypeVarType)] + type_var_items = [t for t in template.items if isinstance(t, TypeVarType)] + return infer_constraints( + UnionType.make_union(non_type_var_items), actual, direction + ) or infer_constraints(UnionType.make_union(type_var_items), actual, direction) + + def any_constraints(options: List[Optional[List[Constraint]]], eager: bool) -> List[Constraint]: """Deduce what we can from a collection of constraint lists. diff --git a/mypy/infer.py b/mypy/infer.py index 1c00d2904702..d3ad0bc19f9b 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -34,7 +34,6 @@ def infer_function_type_arguments( formal_to_actual: List[List[int]], context: ArgumentInferContext, strict: bool = True, - infer_unions: bool = False, ) -> List[Optional[Type]]: """Infer the type arguments of a generic function. @@ -56,7 +55,7 @@ def infer_function_type_arguments( # Solve constraints. type_vars = callee_type.type_var_ids() - return solve_constraints(type_vars, constraints, strict, infer_unions=infer_unions) + return solve_constraints(type_vars, constraints, strict) def infer_type_arguments( diff --git a/mypy/solve.py b/mypy/solve.py index 918308625742..90bbd5b9d3b5 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -17,13 +17,11 @@ UnionType, get_proper_type, ) +from mypy.typestate import TypeState def solve_constraints( - vars: List[TypeVarId], - constraints: List[Constraint], - strict: bool = True, - infer_unions: bool = False, + vars: List[TypeVarId], constraints: List[Constraint], strict: bool = True ) -> List[Optional[Type]]: """Solve type constraints. @@ -55,7 +53,7 @@ def solve_constraints( if bottom is None: bottom = c.target else: - if infer_unions: + if TypeState.infer_unions: # This deviates from the general mypy semantics because # recursive types are union-heavy in 95% of cases. bottom = UnionType.make_union([bottom, c.target]) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 5756c581e53a..5a8c5e38b2fa 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -145,14 +145,7 @@ def is_subtype( ), "Don't pass both context and individual flags" if TypeState.is_assumed_subtype(left, right): return True - if ( - # TODO: recursive instances like `class str(Sequence[str])` can also cause - # issues, so we also need to include them in the assumptions stack - isinstance(left, TypeAliasType) - and isinstance(right, TypeAliasType) - and left.is_recursive - and right.is_recursive - ): + if mypy.typeops.is_recursive_pair(left, right): # This case requires special care because it may cause infinite recursion. # Our view on recursive types is known under a fancy name of iso-recursive mu-types. # Roughly this means that a recursive type is defined as an alias where right hand side @@ -205,12 +198,7 @@ def is_proper_subtype( ), "Don't pass both context and individual flags" if TypeState.is_assumed_proper_subtype(left, right): return True - if ( - isinstance(left, TypeAliasType) - and isinstance(right, TypeAliasType) - and left.is_recursive - and right.is_recursive - ): + if mypy.typeops.is_recursive_pair(left, right): # Same as for non-proper subtype, see detailed comment there for explanation. with pop_on_exit(TypeState.get_assumptions(is_proper=True), left, right): return _is_subtype(left, right, subtype_context, proper_subtype=True) @@ -874,7 +862,7 @@ def visit_type_alias_type(self, left: TypeAliasType) -> bool: assert False, f"This should be never called, got {left}" -T = TypeVar("T", Instance, TypeAliasType) +T = TypeVar("T", bound=Type) @contextmanager diff --git a/mypy/typeops.py b/mypy/typeops.py index f7b14c710cc2..ef3ec1de24c9 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -63,13 +63,25 @@ def is_recursive_pair(s: Type, t: Type) -> bool: - """Is this a pair of recursive type aliases?""" - return ( - isinstance(s, TypeAliasType) - and isinstance(t, TypeAliasType) - and s.is_recursive - and t.is_recursive - ) + """Is this a pair of recursive types? + + There may be more cases, and we may be forced to use e.g. has_recursive_types() + here, but this function is called in very hot code, so we try to keep it simple + and return True only in cases we know may have problems. + """ + if isinstance(s, TypeAliasType) and s.is_recursive: + return ( + isinstance(get_proper_type(t), Instance) + or isinstance(t, TypeAliasType) + and t.is_recursive + ) + if isinstance(t, TypeAliasType) and t.is_recursive: + return ( + isinstance(get_proper_type(s), Instance) + or isinstance(s, TypeAliasType) + and s.is_recursive + ) + return False def tuple_fallback(typ: TupleType) -> Instance: @@ -81,9 +93,8 @@ def tuple_fallback(typ: TupleType) -> Instance: return typ.partial_fallback items = [] for item in typ.items: - proper_type = get_proper_type(item) - if isinstance(proper_type, UnpackType): - unpacked_type = get_proper_type(proper_type.type) + if isinstance(item, UnpackType): + unpacked_type = get_proper_type(item.type) if isinstance(unpacked_type, TypeVarTupleType): items.append(unpacked_type.upper_bound) elif isinstance(unpacked_type, TupleType): diff --git a/mypy/typestate.py b/mypy/typestate.py index 389dc9c2a358..a1d2ab972a11 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -9,7 +9,7 @@ from mypy.nodes import TypeInfo from mypy.server.trigger import make_trigger -from mypy.types import Instance, Type, TypeAliasType, get_proper_type +from mypy.types import Instance, Type, get_proper_type # Represents that the 'left' instance is a subtype of the 'right' instance SubtypeRelationship: _TypeAlias = Tuple[Instance, Instance] @@ -80,10 +80,12 @@ class TypeState: # recursive type aliases. Normally, one would pass type assumptions as an additional # arguments to is_subtype(), but this would mean updating dozens of related functions # threading this through all callsites (see also comment for TypeInfo.assuming). - _assuming: Final[List[Tuple[TypeAliasType, TypeAliasType]]] = [] - _assuming_proper: Final[List[Tuple[TypeAliasType, TypeAliasType]]] = [] + _assuming: Final[List[Tuple[Type, Type]]] = [] + _assuming_proper: Final[List[Tuple[Type, Type]]] = [] # Ditto for inference of generic constraints against recursive type aliases. - _inferring: Final[List[TypeAliasType]] = [] + inferring: Final[List[Tuple[Type, Type]]] = [] + # Whether to use joins or unions when solving constraints, see checkexpr.py for details. + infer_unions: ClassVar = False # N.B: We do all of the accesses to these properties through # TypeState, instead of making these classmethods and accessing @@ -109,7 +111,7 @@ def is_assumed_proper_subtype(left: Type, right: Type) -> bool: return False @staticmethod - def get_assumptions(is_proper: bool) -> List[Tuple[TypeAliasType, TypeAliasType]]: + def get_assumptions(is_proper: bool) -> List[Tuple[Type, Type]]: if is_proper: return TypeState._assuming_proper return TypeState._assuming diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index ac2065c55f18..04b7d634d4a9 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -60,6 +60,22 @@ x: Nested[int] = [1, [2, [3]]] x = [1, [Bad()]] # E: List item 0 has incompatible type "Bad"; expected "Union[int, Nested[int]]" [builtins fixtures/isinstancelist.pyi] +[case testRecursiveAliasGenericInferenceNested] +# flags: --enable-recursive-aliases +from typing import Union, TypeVar, Sequence, List + +T = TypeVar("T") +class A: ... +class B(A): ... + +Nested = Sequence[Union[T, Nested[T]]] + +def flatten(arg: Nested[T]) -> List[T]: ... +reveal_type(flatten([[B(), B()]])) # N: Revealed type is "builtins.list[__main__.B]" +reveal_type(flatten([[[[B()]]]])) # N: Revealed type is "builtins.list[__main__.B]" +reveal_type(flatten([[B(), [[B()]]]])) # N: Revealed type is "builtins.list[__main__.B]" +[builtins fixtures/isinstancelist.pyi] + [case testRecursiveAliasNewStyleSupported] # flags: --enable-recursive-aliases from test import A @@ -278,3 +294,97 @@ if isinstance(b[0], Sequence): a = b[0] x = a # E: Incompatible types in assignment (expression has type "Sequence[Union[B, NestedB]]", variable has type "int") [builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasWithRecursiveInstance] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar + +class A: ... +T = TypeVar("T") +Nested = Sequence[Union[T, Nested[T]]] +class B(Sequence[B]): ... + +a: Nested[A] +aa: Nested[A] +b: B +a = b # OK +a = [[b]] # OK +b = aa # E: Incompatible types in assignment (expression has type "Nested[A]", variable has type "B") + +def join(a: T, b: T) -> T: ... +reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" +reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" +[builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasWithRecursiveInstanceInference] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar, List + +T = TypeVar("T") +Nested = Sequence[Union[T, Nested[T]]] +class B(Sequence[B]): ... + +nb: Nested[B] = [B(), [B(), [B()]]] +lb: List[B] + +def foo(x: Nested[T]) -> T: ... +reveal_type(foo(lb)) # N: Revealed type is "__main__.B" +reveal_type(foo([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" + +NestedInv = List[Union[T, NestedInv[T]]] +nib: NestedInv[B] = [B(), [B(), [B()]]] +def bar(x: NestedInv[T]) -> T: ... +reveal_type(bar(nib)) # N: Revealed type is "__main__.B" +[builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasTopUnion] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar, List + +class A: ... +class B(A): ... + +T = TypeVar("T") +PlainNested = Union[T, Sequence[PlainNested[T]]] + +x: PlainNested[A] +y: PlainNested[B] = [B(), [B(), [B()]]] +x = y # OK + +xx: PlainNested[B] +yy: PlainNested[A] +xx = yy # E: Incompatible types in assignment (expression has type "PlainNested[A]", variable has type "PlainNested[B]") + +def foo(arg: PlainNested[T]) -> T: ... +lb: List[B] +reveal_type(foo([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" +reveal_type(foo(lb)) # N: Revealed type is "__main__.B" +reveal_type(foo(xx)) # N: Revealed type is "__main__.B" +[builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasInferenceExplicitNonRecursive] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar, List + +T = TypeVar("T") +Nested = Sequence[Union[T, Nested[T]]] +PlainNested = Union[T, Sequence[PlainNested[T]]] + +def foo(x: Nested[T]) -> T: ... +def bar(x: PlainNested[T]) -> T: ... + +class A: ... +a: A +la: List[A] +lla: List[Union[A, List[A]]] +llla: List[Union[A, List[Union[A, List[A]]]]] + +reveal_type(foo(la)) # N: Revealed type is "__main__.A" +reveal_type(foo(lla)) # N: Revealed type is "__main__.A" +reveal_type(foo(llla)) # N: Revealed type is "__main__.A" + +reveal_type(bar(a)) # N: Revealed type is "__main__.A" +reveal_type(bar(la)) # N: Revealed type is "__main__.A" +reveal_type(bar(lla)) # N: Revealed type is "__main__.A" +reveal_type(bar(llla)) # N: Revealed type is "__main__.A" +[builtins fixtures/isinstancelist.pyi]