From 08a88154567964c4079b2657e7e017a011dbbf4e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 17 May 2023 23:35:47 +0100 Subject: [PATCH 01/16] Start working on generic stuff --- mypy/checkexpr.py | 69 ++++++++++++++++++- mypy/constraints.py | 53 ++++++++++++++ mypy/infer.py | 12 +++- mypy/solve.py | 32 ++++++++- mypy/subtypes.py | 2 +- test-data/unit/check-generics.test | 42 +++++++++++ .../unit/check-parameter-specification.test | 14 ++++ 7 files changed, 216 insertions(+), 8 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index f9fbd53866da..3601cd753134 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -100,6 +100,7 @@ from mypy.state import state from mypy.subtypes import is_equivalent, is_same_type, is_subtype, non_method_protocol_members from mypy.traverser import has_await_expression +from mypy.type_visitor import TypeTranslator from mypy.typeanal import ( check_for_explicit_any, has_any_from_unimported_type, @@ -120,7 +121,7 @@ true_only, try_expanding_sum_type_to_union, try_getting_str_literals, - tuple_fallback, + tuple_fallback, get_type_vars, ) from mypy.types import ( LITERAL_TYPE_NAMES, @@ -155,7 +156,7 @@ get_proper_type, get_proper_types, has_recursive_types, - is_named_instance, + is_named_instance, TypeVarLikeType, ) from mypy.types_utils import is_generic_instance, is_optional, is_self_type_like, remove_optional from mypy.typestate import type_state @@ -1789,6 +1790,28 @@ def infer_function_type_arguments( inferred_args[0] = self.named_type("builtins.str") elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg): self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context) + + # TODO: Filter away ParamSpec + if any(a is None or isinstance(a, UninhabitedType) for a in inferred_args): + poly_inferred_args = infer_function_type_arguments( + callee_type, + arg_types, + arg_kinds, + formal_to_actual, + context=self.argument_infer_context(), + strict=self.chk.in_checked_function(), + allow_polymorphic=True, + ) + for i, arg in enumerate(get_proper_types(poly_inferred_args)): + if isinstance(arg, (NoneType, UninhabitedType)) or has_erased_component(arg): + poly_inferred_args[i] = None + poly_callee_type = self.apply_generic_arguments(callee_type, poly_inferred_args, context) + yes_vars = poly_callee_type.variables + no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables} + if not set(get_type_vars(poly_callee_type)) & no_vars: + applied = apply_poly(poly_callee_type, yes_vars) + if applied is not None: + return applied else: # In dynamically typed functions use implicit 'Any' types for # type variables. 
@@ -5290,6 +5313,48 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl return c.copy_modified(ret_type=new_ret_type) +def apply_poly(tp: CallableType, poly_tvars: list[TypeVarLikeType]) -> Optional[CallableType]: + try: + return tp.copy_modified( + arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types], + ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)), + variables=[], + ) + except PolyTranslationError: + return None + + +class PolyTranslationError(TypeError): + pass + + +class PolyTranslator(TypeTranslator): + def __init__(self, poly_tvars: list[TypeVarLikeType]) -> None: + self.poly_tvars = set(poly_tvars) + self.bound_tvars = set() + + def visit_callable_type(self, t: CallableType) -> Type: + found_vars = set() + for arg in t.arg_types: + found_vars |= set(get_type_vars(arg)) + found_vars &= self.poly_tvars + found_vars -= self.bound_tvars + self.bound_tvars |= found_vars + result = super().visit_callable_type(t) + self.bound_tvars -= found_vars + assert isinstance(result, CallableType) + result.variables += list(found_vars) + return result + + def visit_type_var(self, t: TypeVarType) -> Type: + if t in self.poly_tvars and t not in self.bound_tvars: + raise PolyTranslationError() + return super().visit_type_var(t) + + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + return t.copy_modified(args=[a.accept(self) for a in t.args]) + + class ArgInferSecondPassQuery(types.BoolTypeQuery): """Query whether an argument type should be inferred in the second pass. diff --git a/mypy/constraints.py b/mypy/constraints.py index 9a662f1004f7..794c03a62e4d 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -11,6 +11,7 @@ from mypy.erasetype import erase_typevars from mypy.maptype import map_instance_to_supertype from mypy.nodes import ARG_OPT, ARG_POS, CONTRAVARIANT, COVARIANT, ArgKind +from mypy.type_visitor import BoolTypeQuery, ANY_STRATEGY from mypy.types import ( TUPLE_LIKE_INSTANCE_NAMES, AnyType, @@ -63,6 +64,48 @@ SUPERTYPE_OF: Final = 1 +def flatten_types(tls: list[list[Type]]) -> list[Type]: + res = [] + for tl in tls: + res.extend(tl) + return res + + +class PolyExtractor(TypeQuery[list[TypeVarLikeType]]): + def __init__(self) -> None: + super().__init__(flatten_types) + + def visit_callable_type(self, t: CallableType) -> list[TypeVarLikeType]: + return t.variables + super().visit_callable_type(t) + + +class PolyLeakDetector(BoolTypeQuery): + def __init__(self, found: set[TypeVarLikeType]) -> None: + super().__init__(ANY_STRATEGY) + self.bound = set() + self.found = found + + def visit_callable_type(self, t: CallableType) -> bool: + self.bound |= set(t.variables) + result = super().visit_callable_type(t) + self.bound -= set(t.variables) + return result + + def visit_type_var(self, t: TypeVarType) -> bool: + return t in self.found and t not in self.bound + + +def sanitize_constraints(constraints: list[Constraint], types: list[Type]) -> list[Constraint]: + res = [] + found = set() + for tp in types: + found |= set(tp.accept(PolyExtractor())) + for c in constraints: + if not c.target.accept(PolyLeakDetector(found)): + res.append(c) + return res + + class Constraint: """A representation of a type constraint. 
@@ -168,6 +211,9 @@ def infer_constraints_for_callable( actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i] ) c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF) + p_arg = get_proper_type(callee.arg_types[i]) + if not isinstance(p_arg, CallableType) or p_arg.param_spec() is None: + c = sanitize_constraints(c, [callee.arg_types[i], actual_type]) constraints.extend(c) return constraints @@ -887,6 +933,13 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: if isinstance(self.actual, CallableType): res: list[Constraint] = [] cactual = self.actual.with_unpacked_kwargs() + if cactual.variables and self.direction == SUPERTYPE_OF and template.param_spec() is None: + from mypy.subtypes import unify_generic_callable + + unified = unify_generic_callable(cactual, template, ignore_return=True) + if unified is not None: + cactual = unified + res.extend(infer_constraints(cactual, template, neg_op(self.direction))) param_spec = template.param_spec() if param_spec is None: # FIX verify argument counts diff --git a/mypy/infer.py b/mypy/infer.py index fbec3d7c4278..925c6ad267a5 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -8,11 +8,13 @@ SUBTYPE_OF, SUPERTYPE_OF, infer_constraints, - infer_constraints_for_callable, + infer_constraints_for_callable, Constraint, sanitize_constraints, ) from mypy.nodes import ArgKind from mypy.solve import solve_constraints -from mypy.types import CallableType, Instance, Type, TypeVarId +from mypy.type_visitor import TypeQuery +from mypy.typeops import get_type_vars +from mypy.types import CallableType, Instance, Type, TypeVarId, TypeVarLikeType, ParamSpecType, get_proper_type class ArgumentInferContext(NamedTuple): @@ -36,6 +38,7 @@ def infer_function_type_arguments( formal_to_actual: list[list[int]], context: ArgumentInferContext, strict: bool = True, + allow_polymorphic: bool = False, ) -> list[Type | None]: """Infer the type arguments of a generic function. @@ -57,7 +60,7 @@ def infer_function_type_arguments( # Solve constraints. type_vars = callee_type.type_var_ids() - return solve_constraints(type_vars, constraints, strict) + return solve_constraints(type_vars, constraints, strict, allow_polymorphic) def infer_type_arguments( @@ -66,4 +69,7 @@ def infer_type_arguments( # Like infer_function_type_arguments, but only match a single type # against a generic type. 
constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF) + p_template = get_proper_type(template) + if not isinstance(p_template, CallableType) or p_template.param_spec() is None: + constraints = sanitize_constraints(constraints, [template, actual]) return solve_constraints(type_var_ids, constraints) diff --git a/mypy/solve.py b/mypy/solve.py index b8304d29c1ce..fe7ac78da986 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -4,10 +4,12 @@ from collections import defaultdict -from mypy.constraints import SUPERTYPE_OF, Constraint +from mypy.constraints import SUPERTYPE_OF, Constraint, neg_op from mypy.join import join_types from mypy.meet import meet_types from mypy.subtypes import is_subtype +from mypy.typeanal import remove_dups +from mypy.typeops import get_type_vars from mypy.types import ( AnyType, ProperType, @@ -17,12 +19,27 @@ UninhabitedType, UnionType, get_proper_type, + ParamSpecType, + TypeVarType, ) from mypy.typestate import type_state +def remove_mirror(constraints: list[Constraint]) -> list[Constraint]: + seen = set() + result = [] + for c in constraints: + if isinstance(c.target, TypeVarType): + if (c.target.id, neg_op(c.op), c.type_var) in seen: + continue + seen.add((c.type_var, c.op, c.target.id)) + result.append(c) + return result + + def solve_constraints( - vars: list[TypeVarId], constraints: list[Constraint], strict: bool = True + vars: list[TypeVarId], constraints: list[Constraint], strict: bool = True, + allow_polymorphic: bool = False, ) -> list[Type | None]: """Solve type constraints. @@ -33,12 +50,19 @@ def solve_constraints( pick NoneType as the value of the type variable. If strict=False, pick AnyType. """ + constraints = remove_dups(constraints) + constraints = remove_mirror(constraints) + # Collect a list of constraints for each type variable. cmap: dict[TypeVarId, list[Constraint]] = defaultdict(list) for con in constraints: cmap[con.type_var].append(con) res: list[Type | None] = [] + if allow_polymorphic: + extra: set[TypeVarId] = set() + else: + extra = set(vars) # Solve each type variable separately. for tvar in vars: @@ -50,6 +74,10 @@ def solve_constraints( # bounds based on constraints. Note that we assume that the constraint # targets do not have constraint references. for c in cmap.get(tvar, []): + if set(t.id for t in get_type_vars(c.target)) & ({tvar} | extra): + if not isinstance(c.origin_type_var, ParamSpecType): + # TODO: figure out def [U] (U) -> U vs itself + continue if c.op == SUPERTYPE_OF: if bottom is None: bottom = c.target diff --git a/mypy/subtypes.py b/mypy/subtypes.py index b26aee1a92af..f509e75ce28b 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -8,7 +8,7 @@ import mypy.constraints import mypy.typeops from mypy.erasetype import erase_type -from mypy.expandtype import expand_self_type, expand_type_by_instance +from mypy.expandtype import expand_self_type, expand_type_by_instance, freshen_function_type_vars from mypy.maptype import map_instance_to_supertype # Circular import; done in the function instead. 
diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 06b80be85096..9ec5f800aeb6 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2733,3 +2733,45 @@ dict1: Any dict2 = {"a": C1(), **{x: C2() for x in dict1}} reveal_type(dict2) # N: Revealed type is "builtins.dict[Any, __main__.B]" [builtins fixtures/dict.pyi] + +[case testGenericStuff] +from typing import TypeVar, Callable, List + +X = TypeVar('X') +T = TypeVar('T') + +def foo(x: Callable[[int], X]) -> List[X]: + ... +def id(x: T) -> T: + ... +y = foo(id) +reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]" +[builtins fixtures/list.pyi] + +[case testHardGenericStuff] +from typing import TypeVar, Callable, List, Sequence + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]: + ... +def id(x: U) -> U: + ... +g = dec(id) +reveal_type(g) # N: +reveal_type(g(42)) + +def comb(f: Callable[[T], S], g: Callable[[S], U]) -> Callable[[T], U]: ... +reveal_type(comb(id, id)) + +def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: + def inner(x: S) -> List[T]: + return [f(x) for f in fs] + return inner + +fs = [id, id, id] +reveal_type(mix(fs)) +reveal_type(mix([id, id, id])) +[builtins fixtures/list.pyi] diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index fe66b18fbfea..f11fa8ab7f1b 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1520,3 +1520,17 @@ def identity(func: Callable[P, None]) -> Callable[P, None]: ... @identity def f(f: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ... [builtins fixtures/paramspec.pyi] + +[case testParamSpecFoo] +from typing import Callable, List, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") +U = TypeVar("U") + +def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ... +def test(x: U) -> U: ... 
+reveal_type(dec) +reveal_type(dec(test)) +[builtins fixtures/paramspec.pyi] From eb3a1e109732191c0750493f9e69db96dd0d1da9 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 20 May 2023 00:04:20 +0100 Subject: [PATCH 02/16] Make some progress --- mypy/build.py | 106 +--------- mypy/checkexpr.py | 41 ++-- mypy/constraints.py | 66 ++---- mypy/graph_utils.py | 109 ++++++++++ mypy/infer.py | 9 +- mypy/solve.py | 195 ++++++++++++++---- mypy/subtypes.py | 2 +- mypy/test/testgraph.py | 11 +- test-data/unit/check-generics.test | 7 + .../unit/check-parameter-specification.test | 3 + 10 files changed, 323 insertions(+), 226 deletions(-) create mode 100644 mypy/graph_utils.py diff --git a/mypy/build.py b/mypy/build.py index c239afb56236..64e355d4cf88 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -31,14 +31,12 @@ Callable, ClassVar, Dict, - Iterable, Iterator, Mapping, NamedTuple, NoReturn, Sequence, TextIO, - TypeVar, ) from typing_extensions import Final, TypeAlias as _TypeAlias @@ -47,6 +45,7 @@ import mypy.semanal_main from mypy.checker import TypeChecker from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error +from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort from mypy.indirection import TypeIndirectionVisitor from mypy.messages import MessageBuilder from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable, TypeInfo @@ -3465,15 +3464,8 @@ def sorted_components( edges = {id: deps_filtered(graph, vertices, id, pri_max) for id in vertices} sccs = list(strongly_connected_components(vertices, edges)) # Topsort. - sccsmap = {id: frozenset(scc) for scc in sccs for id in scc} - data: dict[AbstractSet[str], set[AbstractSet[str]]] = {} - for scc in sccs: - deps: set[AbstractSet[str]] = set() - for id in scc: - deps.update(sccsmap[x] for x in deps_filtered(graph, vertices, id, pri_max)) - data[frozenset(scc)] = deps res = [] - for ready in topsort(data): + for ready in topsort(prepare_sccs(sccs, edges)): # Sort the sets in ready by reversed smallest State.order. Examples: # # - If ready is [{x}, {y}], x.order == 1, y.order == 2, we get @@ -3498,100 +3490,6 @@ def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: in ] -def strongly_connected_components( - vertices: AbstractSet[str], edges: dict[str, list[str]] -) -> Iterator[set[str]]: - """Compute Strongly Connected Components of a directed graph. - - Args: - vertices: the labels for the vertices - edges: for each vertex, gives the target vertices of its outgoing edges - - Returns: - An iterator yielding strongly connected components, each - represented as a set of vertices. Each input vertex will occur - exactly once; vertices not part of a SCC are returned as - singleton sets. - - From https://code.activestate.com/recipes/578507/. - """ - identified: set[str] = set() - stack: list[str] = [] - index: dict[str, int] = {} - boundaries: list[int] = [] - - def dfs(v: str) -> Iterator[set[str]]: - index[v] = len(stack) - stack.append(v) - boundaries.append(index[v]) - - for w in edges[v]: - if w not in index: - yield from dfs(w) - elif w not in identified: - while index[w] < boundaries[-1]: - boundaries.pop() - - if boundaries[-1] == index[v]: - boundaries.pop() - scc = set(stack[index[v] :]) - del stack[index[v] :] - identified.update(scc) - yield scc - - for v in vertices: - if v not in index: - yield from dfs(v) - - -T = TypeVar("T") - - -def topsort(data: dict[T, set[T]]) -> Iterable[set[T]]: - """Topological sort. 
- - Args: - data: A map from vertices to all vertices that it has an edge - connecting it to. NOTE: This data structure - is modified in place -- for normalization purposes, - self-dependencies are removed and entries representing - orphans are added. - - Returns: - An iterator yielding sets of vertices that have an equivalent - ordering. - - Example: - Suppose the input has the following structure: - - {A: {B, C}, B: {D}, C: {D}} - - This is normalized to: - - {A: {B, C}, B: {D}, C: {D}, D: {}} - - The algorithm will yield the following values: - - {D} - {B, C} - {A} - - From https://code.activestate.com/recipes/577413/. - """ - # TODO: Use a faster algorithm? - for k, v in data.items(): - v.discard(k) # Ignore self dependencies. - for item in set.union(*data.values()) - set(data.keys()): - data[item] = set() - while True: - ready = {item for item, dep in data.items() if not dep} - if not ready: - break - yield ready - data = {item: (dep - ready) for item, dep in data.items() if item not in ready} - assert not data, f"A cyclic dependency exists amongst {data!r}" - - def missing_stubs_file(cache_dir: str) -> str: return os.path.join(cache_dir, "missing_stubs") diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3601cd753134..7f8fc333b9bb 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -12,7 +12,7 @@ import mypy.errorcodes as codes from mypy import applytype, erasetype, join, message_registry, nodes, operators, types from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals -from mypy.checkmember import analyze_member_access, type_object_type +from mypy.checkmember import analyze_member_access, freeze_all_type_vars, type_object_type from mypy.checkstrformat import StringFormatterChecker from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars from mypy.errors import ErrorWatcher, report_internal_error @@ -115,13 +115,14 @@ false_only, fixup_partial_type, function_type, + get_type_vars, is_literal_type_like, make_simplified_union, simple_literal_type, true_only, try_expanding_sum_type_to_union, try_getting_str_literals, - tuple_fallback, get_type_vars, + tuple_fallback, ) from mypy.types import ( LITERAL_TYPE_NAMES, @@ -147,6 +148,7 @@ TypedDictType, TypeOfAny, TypeType, + TypeVarLikeType, TypeVarTupleType, TypeVarType, UninhabitedType, @@ -156,7 +158,7 @@ get_proper_type, get_proper_types, has_recursive_types, - is_named_instance, TypeVarLikeType, + is_named_instance, ) from mypy.types_utils import is_generic_instance, is_optional, is_self_type_like, remove_optional from mypy.typestate import type_state @@ -1791,8 +1793,10 @@ def infer_function_type_arguments( elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg): self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context) - # TODO: Filter away ParamSpec - if any(a is None or isinstance(a, UninhabitedType) for a in inferred_args): + # TODO: Filter away (or handle) ParamSpec? + if any( + a is None or isinstance(get_proper_type(a), UninhabitedType) for a in inferred_args + ): poly_inferred_args = infer_function_type_arguments( callee_type, arg_types, @@ -1802,15 +1806,21 @@ def infer_function_type_arguments( strict=self.chk.in_checked_function(), allow_polymorphic=True, ) - for i, arg in enumerate(get_proper_types(poly_inferred_args)): - if isinstance(arg, (NoneType, UninhabitedType)) or has_erased_component(arg): + for i, pa in enumerate(get_proper_types(poly_inferred_args)): + # TODO: can we be more principled here? 
+ if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa): poly_inferred_args[i] = None - poly_callee_type = self.apply_generic_arguments(callee_type, poly_inferred_args, context) + poly_callee_type = self.apply_generic_arguments( + callee_type, poly_inferred_args, context + ) yes_vars = poly_callee_type.variables no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables} if not set(get_type_vars(poly_callee_type)) & no_vars: applied = apply_poly(poly_callee_type, yes_vars) - if applied is not None: + if applied is not None and poly_inferred_args != [None] * len( + poly_inferred_args + ): + freeze_all_type_vars(applied) return applied else: # In dynamically typed functions use implicit 'Any' types for @@ -5313,7 +5323,7 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl return c.copy_modified(ret_type=new_ret_type) -def apply_poly(tp: CallableType, poly_tvars: list[TypeVarLikeType]) -> Optional[CallableType]: +def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> Optional[CallableType]: try: return tp.copy_modified( arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types], @@ -5329,9 +5339,9 @@ class PolyTranslationError(TypeError): class PolyTranslator(TypeTranslator): - def __init__(self, poly_tvars: list[TypeVarLikeType]) -> None: + def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None: self.poly_tvars = set(poly_tvars) - self.bound_tvars = set() + self.bound_tvars: set[TypeVarLikeType] = set() def visit_callable_type(self, t: CallableType) -> Type: found_vars = set() @@ -5342,8 +5352,9 @@ def visit_callable_type(self, t: CallableType) -> Type: self.bound_tvars |= found_vars result = super().visit_callable_type(t) self.bound_tvars -= found_vars + assert isinstance(result, ProperType) assert isinstance(result, CallableType) - result.variables += list(found_vars) + result.variables = list(result.variables) + list(found_vars) return result def visit_type_var(self, t: TypeVarType) -> Type: @@ -5351,6 +5362,10 @@ def visit_type_var(self, t: TypeVarType) -> Type: raise PolyTranslationError() return super().visit_type_var(t) + def visit_param_spec(self, t: ParamSpecType) -> Type: + # TODO: more careful here (also handle TypeVarTupleType) + raise PolyTranslationError() + def visit_type_alias_type(self, t: TypeAliasType) -> Type: return t.copy_modified(args=[a.accept(self) for a in t.args]) diff --git a/mypy/constraints.py b/mypy/constraints.py index 794c03a62e4d..eb096c44c7c1 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -11,7 +11,6 @@ from mypy.erasetype import erase_typevars from mypy.maptype import map_instance_to_supertype from mypy.nodes import ARG_OPT, ARG_POS, CONTRAVARIANT, COVARIANT, ArgKind -from mypy.type_visitor import BoolTypeQuery, ANY_STRATEGY from mypy.types import ( TUPLE_LIKE_INSTANCE_NAMES, AnyType, @@ -64,48 +63,6 @@ SUPERTYPE_OF: Final = 1 -def flatten_types(tls: list[list[Type]]) -> list[Type]: - res = [] - for tl in tls: - res.extend(tl) - return res - - -class PolyExtractor(TypeQuery[list[TypeVarLikeType]]): - def __init__(self) -> None: - super().__init__(flatten_types) - - def visit_callable_type(self, t: CallableType) -> list[TypeVarLikeType]: - return t.variables + super().visit_callable_type(t) - - -class PolyLeakDetector(BoolTypeQuery): - def __init__(self, found: set[TypeVarLikeType]) -> None: - super().__init__(ANY_STRATEGY) - self.bound = set() - self.found = found - - def visit_callable_type(self, t: CallableType) -> bool: 
- self.bound |= set(t.variables) - result = super().visit_callable_type(t) - self.bound -= set(t.variables) - return result - - def visit_type_var(self, t: TypeVarType) -> bool: - return t in self.found and t not in self.bound - - -def sanitize_constraints(constraints: list[Constraint], types: list[Type]) -> list[Constraint]: - res = [] - found = set() - for tp in types: - found |= set(tp.accept(PolyExtractor())) - for c in constraints: - if not c.target.accept(PolyLeakDetector(found)): - res.append(c) - return res - - class Constraint: """A representation of a type constraint. @@ -211,9 +168,6 @@ def infer_constraints_for_callable( actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i] ) c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF) - p_arg = get_proper_type(callee.arg_types[i]) - if not isinstance(p_arg, CallableType) or p_arg.param_spec() is None: - c = sanitize_constraints(c, [callee.arg_types[i], actual_type]) constraints.extend(c) return constraints @@ -933,17 +887,21 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: if isinstance(self.actual, CallableType): res: list[Constraint] = [] cactual = self.actual.with_unpacked_kwargs() - if cactual.variables and self.direction == SUPERTYPE_OF and template.param_spec() is None: - from mypy.subtypes import unify_generic_callable - - unified = unify_generic_callable(cactual, template, ignore_return=True) - if unified is not None: - cactual = unified - res.extend(infer_constraints(cactual, template, neg_op(self.direction))) param_spec = template.param_spec() if param_spec is None: # FIX verify argument counts - # FIX what if one of the functions is generic + # TODO: Erase template vars if generic? + if ( + cactual.variables + and self.direction == SUPERTYPE_OF + and cactual.param_spec() is None + ): + from mypy.subtypes import unify_generic_callable + + unified = unify_generic_callable(cactual, template, ignore_return=True) + if unified is not None: + cactual = unified + res.extend(infer_constraints(cactual, template, neg_op(self.direction))) # We can't infer constraints from arguments if the template is Callable[..., T] # (with literal '...'). diff --git a/mypy/graph_utils.py b/mypy/graph_utils.py new file mode 100644 index 000000000000..769cf081e080 --- /dev/null +++ b/mypy/graph_utils.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import AbstractSet, Iterable, Iterator, TypeVar + +T = TypeVar("T") + + +def strongly_connected_components( + vertices: AbstractSet[T], edges: dict[T, list[T]] +) -> Iterator[set[T]]: + """Compute Strongly Connected Components of a directed graph. + + Args: + vertices: the labels for the vertices + edges: for each vertex, gives the target vertices of its outgoing edges + + Returns: + An iterator yielding strongly connected components, each + represented as a set of vertices. Each input vertex will occur + exactly once; vertices not part of a SCC are returned as + singleton sets. + + From https://code.activestate.com/recipes/578507/. 
+ """ + identified: set[T] = set() + stack: list[T] = [] + index: dict[T, int] = {} + boundaries: list[int] = [] + + def dfs(v: T) -> Iterator[set[T]]: + index[v] = len(stack) + stack.append(v) + boundaries.append(index[v]) + + for w in edges[v]: + if w not in index: + yield from dfs(w) + elif w not in identified: + while index[w] < boundaries[-1]: + boundaries.pop() + + if boundaries[-1] == index[v]: + boundaries.pop() + scc = set(stack[index[v] :]) + del stack[index[v] :] + identified.update(scc) + yield scc + + for v in vertices: + if v not in index: + yield from dfs(v) + + +def prepare_sccs( + sccs: list[set[T]], edges: dict[T, list[T]] +) -> dict[AbstractSet[T], set[AbstractSet[T]]]: + sccsmap = {v: frozenset(scc) for scc in sccs for v in scc} + data: dict[AbstractSet[T], set[AbstractSet[T]]] = {} + for scc in sccs: + deps: set[AbstractSet[T]] = set() + for v in scc: + deps.update(sccsmap[x] for x in edges[v]) + data[frozenset(scc)] = deps + return data + + +def topsort(data: dict[T, set[T]]) -> Iterable[set[T]]: + """Topological sort. + + Args: + data: A map from vertices to all vertices that it has an edge + connecting it to. NOTE: This data structure + is modified in place -- for normalization purposes, + self-dependencies are removed and entries representing + orphans are added. + + Returns: + An iterator yielding sets of vertices that have an equivalent + ordering. + + Example: + Suppose the input has the following structure: + + {A: {B, C}, B: {D}, C: {D}} + + This is normalized to: + + {A: {B, C}, B: {D}, C: {D}, D: {}} + + The algorithm will yield the following values: + + {D} + {B, C} + {A} + + From https://code.activestate.com/recipes/577413/. + """ + # TODO: Use a faster algorithm? + for k, v in data.items(): + v.discard(k) # Ignore self dependencies. + for item in set.union(*data.values()) - set(data.keys()): + data[item] = set() + while True: + ready = {item for item, dep in data.items() if not dep} + if not ready: + break + yield ready + data = {item: (dep - ready) for item, dep in data.items() if item not in ready} + assert not data, f"A cyclic dependency exists amongst {data!r}" diff --git a/mypy/infer.py b/mypy/infer.py index 925c6ad267a5..66ca4169e2ff 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -8,13 +8,11 @@ SUBTYPE_OF, SUPERTYPE_OF, infer_constraints, - infer_constraints_for_callable, Constraint, sanitize_constraints, + infer_constraints_for_callable, ) from mypy.nodes import ArgKind from mypy.solve import solve_constraints -from mypy.type_visitor import TypeQuery -from mypy.typeops import get_type_vars -from mypy.types import CallableType, Instance, Type, TypeVarId, TypeVarLikeType, ParamSpecType, get_proper_type +from mypy.types import CallableType, Instance, Type, TypeVarId class ArgumentInferContext(NamedTuple): @@ -69,7 +67,4 @@ def infer_type_arguments( # Like infer_function_type_arguments, but only match a single type # against a generic type. 
constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF) - p_template = get_proper_type(template) - if not isinstance(p_template, CallableType) or p_template.param_spec() is None: - constraints = sanitize_constraints(constraints, [template, actual]) return solve_constraints(type_var_ids, constraints) diff --git a/mypy/solve.py b/mypy/solve.py index fe7ac78da986..08928b03818a 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -2,9 +2,9 @@ from __future__ import annotations -from collections import defaultdict - -from mypy.constraints import SUPERTYPE_OF, Constraint, neg_op +from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op +from mypy.expandtype import expand_type +from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort from mypy.join import join_types from mypy.meet import meet_types from mypy.subtypes import is_subtype @@ -16,29 +16,18 @@ Type, TypeOfAny, TypeVarId, + TypeVarType, UninhabitedType, UnionType, get_proper_type, - ParamSpecType, - TypeVarType, ) from mypy.typestate import type_state -def remove_mirror(constraints: list[Constraint]) -> list[Constraint]: - seen = set() - result = [] - for c in constraints: - if isinstance(c.target, TypeVarType): - if (c.target.id, neg_op(c.op), c.type_var) in seen: - continue - seen.add((c.type_var, c.op, c.target.id)) - result.append(c) - return result - - def solve_constraints( - vars: list[TypeVarId], constraints: list[Constraint], strict: bool = True, + vars: list[TypeVarId], + constraints: list[Constraint], + strict: bool = True, allow_polymorphic: bool = False, ) -> list[Type | None]: """Solve type constraints. @@ -50,20 +39,92 @@ def solve_constraints( pick NoneType as the value of the type variable. If strict=False, pick AnyType. """ - constraints = remove_dups(constraints) - constraints = remove_mirror(constraints) + if allow_polymorphic: + constraints = normalize_constraints(constraints, vars) # Collect a list of constraints for each type variable. - cmap: dict[TypeVarId, list[Constraint]] = defaultdict(list) + cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars} for con in constraints: - cmap[con.type_var].append(con) + if con.type_var in vars: + cmap[con.type_var].append(con) - res: list[Type | None] = [] if allow_polymorphic: - extra: set[TypeVarId] = set() - else: - extra = set(vars) + extra_constraints = [] + for tvar in vars: + extra_constraints.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap)) + extra_constraints.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap)) + constraints += remove_dups(extra_constraints) + + # Recompute constraint map after propagating. + cmap = {tv: [] for tv in vars} + for con in constraints: + if con.type_var in vars: + cmap[con.type_var].append(con) + + solutions, remaining = solve_iteratively(vars, cmap, [] if allow_polymorphic else vars) + + # TODO: only do this if we have actual non-trivial constraints, not just unconstrained vars. + if remaining and allow_polymorphic: + # TODO: factor this out into separate function. 
+ rest_cmap = { + tv: [c for c in cs if get_vars(c.target, remaining)] + for (tv, cs) in cmap.items() + if tv in remaining + } + dmap = compute_dependencies(rest_cmap) + sccs = list(strongly_connected_components(set(remaining), dmap)) + if all(check_linear(scc, rest_cmap) for scc in sccs): + leafs = next(batch for batch in topsort(prepare_sccs(sccs, dmap))) + free_vars = [] + for scc in leafs: + free_vars.append(next(tv for tv in scc)) + + solutions, _ = solve_iteratively(vars, cmap, free_vars) + for tv in free_vars: + if tv in solutions: + del solutions[tv] + + res: list[Type | None] = [] + for v in vars: + if v in solutions: + res.append(solutions[v]) + else: + # No constraints for type variable -- 'UninhabitedType' is the most specific type. + candidate: Type + if strict: + candidate = UninhabitedType() + candidate.ambiguous = True + else: + candidate = AnyType(TypeOfAny.special_form) + res.append(candidate) + return res + + +def solve_iteratively( + vars: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] +) -> tuple[dict[TypeVarId, Type | None], list[TypeVarId]]: + solutions: dict[TypeVarId, Type | None] = {} + remaining = vars + while True: + tmap = solve_once(remaining, cmap, free_vars) + if not tmap: + break + remaining = [v for v in remaining if v not in tmap] + for v in remaining: + for c in cmap[v]: + # TODO: handle bound violations etc. + # TODO: limit number of expands by only including *new* things in tmap. + c.target = expand_type( + c.target, {k: v for (k, v) in tmap.items() if v is not None} + ) + solutions.update(tmap) + return solutions, remaining + +def solve_once( + vars: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] +) -> dict[TypeVarId, Type | None]: + res: dict[TypeVarId, Type | None] = {} # Solve each type variable separately. for tvar in vars: bottom: Type | None = None @@ -74,10 +135,8 @@ def solve_constraints( # bounds based on constraints. Note that we assume that the constraint # targets do not have constraint references. for c in cmap.get(tvar, []): - if set(t.id for t in get_type_vars(c.target)) & ({tvar} | extra): - if not isinstance(c.origin_type_var, ParamSpecType): - # TODO: figure out def [U] (U) -> U vs itself - continue + if get_vars(c.target, [v for v in vars if v not in free_vars]): + continue if c.op == SUPERTYPE_OF: if bottom is None: bottom = c.target @@ -99,24 +158,84 @@ def solve_constraints( if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType): source_any = top if isinstance(p_top, AnyType) else bottom assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType) - res.append(AnyType(TypeOfAny.from_another_any, source_any=source_any)) + res[tvar] = AnyType(TypeOfAny.from_another_any, source_any=source_any) continue elif bottom is None: if top: candidate = top else: - # No constraints for type variable -- 'UninhabitedType' is the most specific type. 
- if strict: - candidate = UninhabitedType() - candidate.ambiguous = True - else: - candidate = AnyType(TypeOfAny.special_form) + # No constraints for type variable + continue elif top is None: candidate = bottom elif is_subtype(bottom, top): candidate = bottom else: candidate = None - res.append(candidate) + res[tvar] = candidate + return res + +def normalize_constraints( + constraints: list[Constraint], vars: list[TypeVarId] +) -> list[Constraint]: + res = constraints.copy() + for c in constraints: + if isinstance(c.target, TypeVarType): + res.append(Constraint(c.target, neg_op(c.op), c.origin_type_var)) + return [c for c in remove_dups(constraints) if c.type_var in vars] + + +def propagate_constraints_for( + var: TypeVarId, direction: int, cmap: dict[TypeVarId, list[Constraint]] +) -> list[Constraint]: + extra_constraints = [] + seen = set() + front = [var] + if cmap[var]: + var_def = cmap[var][0].origin_type_var + else: + return [] + while front: + tv = front.pop(0) + for c in cmap[tv]: + if ( + isinstance(c.target, TypeVarType) + and c.target.id not in seen + and c.target.id in cmap + and c.op == direction + ): + front.append(c.target.id) + seen.add(c.target.id) + elif c.op == direction: + new_c = Constraint(var_def, direction, c.target) + if new_c not in cmap[var]: + extra_constraints.append(new_c) + return extra_constraints + + +def compute_dependencies( + cmap: dict[TypeVarId, list[Constraint]] +) -> dict[TypeVarId, list[TypeVarId]]: + res = {} + vars = list(cmap.keys()) + for tv in cmap: + deps = set() + for c in cmap[tv]: + deps |= get_vars(c.target, vars) + res[tv] = list(deps) return res + + +def check_linear(scc: set[TypeVarId], cmap: dict[TypeVarId, list[Constraint]]) -> bool: + for tv in scc: + if any( + get_vars(c.target, list(scc)) and not isinstance(c.target, TypeVarType) + for c in cmap[tv] + ): + return False + return True + + +def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]: + return {tv.id for tv in get_type_vars(target)} & set(vars) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index f509e75ce28b..b26aee1a92af 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -8,7 +8,7 @@ import mypy.constraints import mypy.typeops from mypy.erasetype import erase_type -from mypy.expandtype import expand_self_type, expand_type_by_instance, freshen_function_type_vars +from mypy.expandtype import expand_self_type, expand_type_by_instance from mypy.maptype import map_instance_to_supertype # Circular import; done in the function instead. diff --git a/mypy/test/testgraph.py b/mypy/test/testgraph.py index ce7697142ff2..b0d148d5ae9c 100644 --- a/mypy/test/testgraph.py +++ b/mypy/test/testgraph.py @@ -5,17 +5,10 @@ import sys from typing import AbstractSet -from mypy.build import ( - BuildManager, - BuildSourceSet, - State, - order_ascc, - sorted_components, - strongly_connected_components, - topsort, -) +from mypy.build import BuildManager, BuildSourceSet, State, order_ascc, sorted_components from mypy.errors import Errors from mypy.fscache import FileSystemCache +from mypy.graph_utils import strongly_connected_components, topsort from mypy.modulefinder import SearchPaths from mypy.options import Options from mypy.plugin import Plugin diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 9ec5f800aeb6..472db525ac7e 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2746,6 +2746,9 @@ def id(x: T) -> T: ... 
y = foo(id) reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]" + +def chain(f: Callable[[X], T], g: Callable[[T], int]) -> Callable[[X], int]: ... +reveal_type(chain(id, id)) [builtins fixtures/list.pyi] [case testHardGenericStuff] @@ -2774,4 +2777,8 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: fs = [id, id, id] reveal_type(mix(fs)) reveal_type(mix([id, id, id])) + +x: Callable[[T], T] +y: Callable[[U], U] +x = y [builtins fixtures/list.pyi] diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index f11fa8ab7f1b..f71964cb131b 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1030,6 +1030,7 @@ class Job(Generic[_P]): def __init__(self, target: Callable[_P, None]) -> None: ... def into_callable(self) -> Callable[_P, None]: ... +# TODO: add a test with return T: wellcome forall types! def generic_f(x: _T) -> None: ... j = Job(generic_f) @@ -1530,6 +1531,8 @@ T = TypeVar("T") U = TypeVar("U") def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ... +# TODO: challenge: support def dec(f: Callable[P, List[T]]) -> Callable[P, T]: ... +# TODO: support currying examples from reviewed issue. def test(x: U) -> U: ... reveal_type(dec) reveal_type(dec(test)) From 52209c9afa9d2b7a4e2e6ee58f2c40cd5a7403cd Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 22 May 2023 15:17:28 +0100 Subject: [PATCH 03/16] Add/reorganize tests --- mypy/constraints.py | 6 +- test-data/unit/check-generics.test | 78 +++++++++++++++---- .../unit/check-parameter-specification.test | 57 +++++++++++++- 3 files changed, 116 insertions(+), 25 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index eb096c44c7c1..9879742a5267 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -891,11 +891,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: if param_spec is None: # FIX verify argument counts # TODO: Erase template vars if generic? - if ( - cactual.variables - and self.direction == SUPERTYPE_OF - and cactual.param_spec() is None - ): + if cactual.variables and cactual.param_spec() is None: from mypy.subtypes import unify_generic_callable unified = unify_generic_callable(cactual, template, ignore_return=True) diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 472db525ac7e..877aeec83473 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2734,7 +2734,7 @@ dict2 = {"a": C1(), **{x: C2() for x in dict1}} reveal_type(dict2) # N: Revealed type is "builtins.dict[Any, __main__.B]" [builtins fixtures/dict.pyi] -[case testGenericStuff] +[case testInferenceAgainstGenericCallable] from typing import TypeVar, Callable, List X = TypeVar('X') @@ -2742,17 +2742,29 @@ T = TypeVar('T') def foo(x: Callable[[int], X]) -> List[X]: ... +def bar(x: Callable[[X], int]) -> List[X]: + ... + def id(x: T) -> T: ... -y = foo(id) -reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]" +reveal_type(foo(id)) # N: Revealed type is "builtins.list[builtins.int]" +reveal_type(bar(id)) # N: Revealed type is "builtins.list[builtins.int]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableChain] +from typing import TypeVar, Callable, List + +X = TypeVar('X') +T = TypeVar('T') def chain(f: Callable[[X], T], g: Callable[[T], int]) -> Callable[[X], int]: ... -reveal_type(chain(id, id)) +def id(x: T) -> T: + ... 
+reveal_type(chain(id, id)) # N: Revealed type is "def (builtins.int) -> builtins.int" [builtins fixtures/list.pyi] -[case testHardGenericStuff] -from typing import TypeVar, Callable, List, Sequence +[case testInferenceAgainstGenericCallableGeneric] +from typing import TypeVar, Callable, List S = TypeVar('S') T = TypeVar('T') @@ -2762,23 +2774,57 @@ def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]: ... def id(x: U) -> U: ... -g = dec(id) -reveal_type(g) # N: -reveal_type(g(42)) +reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> builtins.list[T`2]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableGenericChain] +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') def comb(f: Callable[[T], S], g: Callable[[S], U]) -> Callable[[T], U]: ... -reveal_type(comb(id, id)) +def id(x: U) -> U: + ... +reveal_type(comb(id, id)) # N: Revealed type is "def [S] (S`2) -> S`2" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableGenericNonLinear] +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: def inner(x: S) -> List[T]: return [f(x) for f in fs] return inner +# Errors caused by arg *name* mismatch are truly cryptic, but this is a known issue :/ +def id(__x: U) -> U: + ... fs = [id, id, id] -reveal_type(mix(fs)) -reveal_type(mix([id, id, id])) - -x: Callable[[T], T] -y: Callable[[U], U] -x = y +reveal_type(mix(fs)) # N: Revealed type is "def [T] (T`4) -> builtins.list[T`4]" +reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`5) -> builtins.list[S`5]" [builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCurry] +from typing import Callable, List, TypeVar + +S = TypeVar("S") +T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") + +def dec1(f: Callable[[T], S]) -> Callable[[], Callable[[T], S]]: ... +def dec2(f: Callable[[T, U], S]) -> Callable[[U], Callable[[T], S]]: ... + +def test1(x: V) -> V: ... +def test2(x: V, y: V) -> V: ... + +reveal_type(dec1(test1)) # N: Revealed type is "def () -> def [S] (S`2) -> S`2" +# TODO: support this situation +reveal_type(dec2(test2)) # N: Revealed type is "def (builtins.object) -> def (builtins.object) -> builtins.object" +[builtins fixtures/paramspec.pyi] diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index f71964cb131b..a307213e50f1 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1030,7 +1030,6 @@ class Job(Generic[_P]): def __init__(self, target: Callable[_P, None]) -> None: ... def into_callable(self) -> Callable[_P, None]: ... -# TODO: add a test with return T: wellcome forall types! def generic_f(x: _T) -> None: ... j = Job(generic_f) @@ -1041,6 +1040,27 @@ reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`-1)" reveal_type(jf(1)) # N: Revealed type is "None" [builtins fixtures/paramspec.pyi] +[case testGenericsInInferredParamspecReturn] +from typing import Callable, TypeVar, Generic +from typing_extensions import ParamSpec + +_P = ParamSpec("_P") +_T = TypeVar("_T") + +class Job(Generic[_P, _T]): + def __init__(self, target: Callable[_P, _T]) -> None: ... + def into_callable(self) -> Callable[_P, _T]: ... + +def generic_f(x: _T) -> _T: ... 
+ +j = Job(generic_f) +reveal_type(j) # N: Revealed type is "__main__.Job[[x: _T`-1], _T`-1]" + +jf = j.into_callable() +reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`-1) -> _T`-1" +reveal_type(jf(1)) # N: Revealed type is "builtins.int" +[builtins fixtures/paramspec.pyi] + [case testStackedConcatenateIsIllegal] from typing_extensions import Concatenate, ParamSpec from typing import Callable @@ -1522,7 +1542,7 @@ def identity(func: Callable[P, None]) -> Callable[P, None]: ... def f(f: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ... [builtins fixtures/paramspec.pyi] -[case testParamSpecFoo] +[case testParamSpecDecoratorAppliedToGeneric] from typing import Callable, List, TypeVar from typing_extensions import ParamSpec @@ -1531,9 +1551,38 @@ T = TypeVar("T") U = TypeVar("U") def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ... -# TODO: challenge: support def dec(f: Callable[P, List[T]]) -> Callable[P, T]: ... -# TODO: support currying examples from reviewed issue. def test(x: U) -> U: ... +reveal_type(dec) # N: Revealed type is "def [P, T] (f: def (*P.args, **P.kwargs) -> T`-2) -> def (*P.args, **P.kwargs) -> builtins.list[T`-2]" +reveal_type(dec(test)) # N: Revealed type is "def [U] (x: U`-1) -> builtins.list[U`-1]" +[builtins fixtures/paramspec.pyi] + +[case testParamSpecDecoratorAppliedToGenericReverse] +from typing import Callable, List, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") +U = TypeVar("U") + +def dec(f: Callable[P, List[T]]) -> Callable[P, T]: ... +def test(x: U) -> U: ... +reveal_type(dec) # N: Revealed type is "def [P, T] (f: def (*P.args, **P.kwargs) -> T`-2) -> def (*P.args, **P.kwargs) -> builtins.list[T`-2]" +reveal_type(dec(test)) # N: Revealed type is "def [U] (x: U`-1) -> builtins.list[U`-1]" +[builtins fixtures/paramspec.pyi] + +[case testParamSpecCurryAppliedToGeneric] +from typing import Callable, List, TypeVar +from typing_extensions import ParamSpec, Concatenate + +P = ParamSpec("P") +S = TypeVar("S") +T = TypeVar("T") +U = TypeVar("U") + +def dec(f: Callable[Concatenate[T, P], S]) -> Callable[P, Callable[[T], S]]: ... +def test(x: U) -> U: ... +def test2(x: U, y: U) -> U: ... reveal_type(dec) reveal_type(dec(test)) +reveal_type(dec(test2)) [builtins fixtures/paramspec.pyi] From a64183d966120450597a37d5dafebc985e275cbc Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 22 May 2023 16:33:14 +0100 Subject: [PATCH 04/16] Solve in topological order --- mypy/constraints.py | 6 +-- mypy/solve.py | 102 +++++++++++++++++++++++++------------------- 2 files changed, 60 insertions(+), 48 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 9879742a5267..1339a0d9adbc 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -892,9 +892,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # FIX verify argument counts # TODO: Erase template vars if generic? 
if cactual.variables and cactual.param_spec() is None: - from mypy.subtypes import unify_generic_callable - - unified = unify_generic_callable(cactual, template, ignore_return=True) + unified = mypy.subtypes.unify_generic_callable( + cactual, template, ignore_return=True + ) if unified is not None: cactual = unified res.extend(infer_constraints(cactual, template, neg_op(self.direction))) diff --git a/mypy/solve.py b/mypy/solve.py index 08928b03818a..50fbe2ce7648 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -39,6 +39,8 @@ def solve_constraints( pick NoneType as the value of the type variable. If strict=False, pick AnyType. """ + if not vars: + return [] if allow_polymorphic: constraints = normalize_constraints(constraints, vars) @@ -49,40 +51,9 @@ def solve_constraints( cmap[con.type_var].append(con) if allow_polymorphic: - extra_constraints = [] - for tvar in vars: - extra_constraints.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap)) - extra_constraints.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap)) - constraints += remove_dups(extra_constraints) - - # Recompute constraint map after propagating. - cmap = {tv: [] for tv in vars} - for con in constraints: - if con.type_var in vars: - cmap[con.type_var].append(con) - - solutions, remaining = solve_iteratively(vars, cmap, [] if allow_polymorphic else vars) - - # TODO: only do this if we have actual non-trivial constraints, not just unconstrained vars. - if remaining and allow_polymorphic: - # TODO: factor this out into separate function. - rest_cmap = { - tv: [c for c in cs if get_vars(c.target, remaining)] - for (tv, cs) in cmap.items() - if tv in remaining - } - dmap = compute_dependencies(rest_cmap) - sccs = list(strongly_connected_components(set(remaining), dmap)) - if all(check_linear(scc, rest_cmap) for scc in sccs): - leafs = next(batch for batch in topsort(prepare_sccs(sccs, dmap))) - free_vars = [] - for scc in leafs: - free_vars.append(next(tv for tv in scc)) - - solutions, _ = solve_iteratively(vars, cmap, free_vars) - for tv in free_vars: - if tv in solutions: - del solutions[tv] + solutions = solve_non_linear(vars, constraints, cmap) + else: + solutions = solve_iteratively([vars], cmap, vars) res: list[Type | None] = [] for v in vars: @@ -100,25 +71,66 @@ def solve_constraints( return res +def solve_non_linear( + vars: list[TypeVarId], constraints: list[Constraint], cmap: dict[TypeVarId, list[Constraint]] +) -> dict[TypeVarId, Type | None]: + extra_constraints = [] + for tvar in vars: + extra_constraints.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap)) + extra_constraints.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap)) + constraints += remove_dups(extra_constraints) + + # Recompute constraint map after propagating. 
+ cmap = {tv: [] for tv in vars} + for con in constraints: + if con.type_var in vars: + cmap[con.type_var].append(con) + + dmap = compute_dependencies(cmap) + sccs = list(strongly_connected_components(set(vars), dmap)) + if all(check_linear(scc, cmap) for scc in sccs): + raw_batches = list(topsort(prepare_sccs(sccs, dmap))) + leafs = raw_batches[0] + free_vars = [] + for scc in leafs: + if all( + isinstance(c.target, TypeVarType) and c.target.id in vars + for tv in scc + for c in cmap[tv] + ): + free_vars.append(next(tv for tv in scc)) + + batches = [] + for batch in raw_batches: + next_bc = [] + for scc in batch: + next_bc.extend(list(scc)) + batches.append(next_bc) + solutions = solve_iteratively(batches, cmap, free_vars) + for tv in free_vars: + if tv in solutions: + del solutions[tv] + return solutions + return {} + + def solve_iteratively( - vars: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] -) -> tuple[dict[TypeVarId, Type | None], list[TypeVarId]]: + batches: list[list[TypeVarId]], + cmap: dict[TypeVarId, list[Constraint]], + free_vars: list[TypeVarId], +) -> dict[TypeVarId, Type | None]: solutions: dict[TypeVarId, Type | None] = {} - remaining = vars - while True: - tmap = solve_once(remaining, cmap, free_vars) + for batch in batches: + tmap = solve_once(batch, cmap, free_vars) if not tmap: - break - remaining = [v for v in remaining if v not in tmap] - for v in remaining: + continue + for v in cmap: for c in cmap[v]: - # TODO: handle bound violations etc. - # TODO: limit number of expands by only including *new* things in tmap. c.target = expand_type( c.target, {k: v for (k, v) in tmap.items() if v is not None} ) solutions.update(tmap) - return solutions, remaining + return solutions def solve_once( From 9191761b642cd3ff7b596abf5875a5be7b10c690 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 22 May 2023 16:41:54 +0100 Subject: [PATCH 05/16] Delete tests that are not going to be supported --- mypy/checkexpr.py | 8 +++-- mypy/constraints.py | 2 +- .../unit/check-parameter-specification.test | 31 ------------------- 3 files changed, 6 insertions(+), 35 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 7f8fc333b9bb..809b2d9d6f43 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1793,7 +1793,6 @@ def infer_function_type_arguments( elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg): self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context) - # TODO: Filter away (or handle) ParamSpec? if any( a is None or isinstance(get_proper_type(a), UninhabitedType) for a in inferred_args ): @@ -1807,7 +1806,6 @@ def infer_function_type_arguments( allow_polymorphic=True, ) for i, pa in enumerate(get_proper_types(poly_inferred_args)): - # TODO: can we be more principled here? if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa): poly_inferred_args[i] = None poly_callee_type = self.apply_generic_arguments( @@ -5363,7 +5361,11 @@ def visit_type_var(self, t: TypeVarType) -> Type: return super().visit_type_var(t) def visit_param_spec(self, t: ParamSpecType) -> Type: - # TODO: more careful here (also handle TypeVarTupleType) + # TODO: Support polymorphic apply for ParamSpec. + raise PolyTranslationError() + + def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: + # TODO: Support polymorphic apply for TypeVarTuple. 
raise PolyTranslationError() def visit_type_alias_type(self, t: TypeAliasType) -> Type: diff --git a/mypy/constraints.py b/mypy/constraints.py index 1339a0d9adbc..00dd6045ff3c 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -890,7 +890,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: param_spec = template.param_spec() if param_spec is None: # FIX verify argument counts - # TODO: Erase template vars if generic? + # TODO: Erase template variables if it is generic? if cactual.variables and cactual.param_spec() is None: unified = mypy.subtypes.unify_generic_callable( cactual, template, ignore_return=True diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index a307213e50f1..eb056b02596d 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1555,34 +1555,3 @@ def test(x: U) -> U: ... reveal_type(dec) # N: Revealed type is "def [P, T] (f: def (*P.args, **P.kwargs) -> T`-2) -> def (*P.args, **P.kwargs) -> builtins.list[T`-2]" reveal_type(dec(test)) # N: Revealed type is "def [U] (x: U`-1) -> builtins.list[U`-1]" [builtins fixtures/paramspec.pyi] - -[case testParamSpecDecoratorAppliedToGenericReverse] -from typing import Callable, List, TypeVar -from typing_extensions import ParamSpec - -P = ParamSpec("P") -T = TypeVar("T") -U = TypeVar("U") - -def dec(f: Callable[P, List[T]]) -> Callable[P, T]: ... -def test(x: U) -> U: ... -reveal_type(dec) # N: Revealed type is "def [P, T] (f: def (*P.args, **P.kwargs) -> T`-2) -> def (*P.args, **P.kwargs) -> builtins.list[T`-2]" -reveal_type(dec(test)) # N: Revealed type is "def [U] (x: U`-1) -> builtins.list[U`-1]" -[builtins fixtures/paramspec.pyi] - -[case testParamSpecCurryAppliedToGeneric] -from typing import Callable, List, TypeVar -from typing_extensions import ParamSpec, Concatenate - -P = ParamSpec("P") -S = TypeVar("S") -T = TypeVar("T") -U = TypeVar("U") - -def dec(f: Callable[Concatenate[T, P], S]) -> Callable[P, Callable[[T], S]]: ... -def test(x: U) -> U: ... -def test2(x: U, y: U) -> U: ... 
-reveal_type(dec) -reveal_type(dec(test)) -reveal_type(dec(test2)) -[builtins fixtures/paramspec.pyi] From ba0f2526d450daffc082ba6204af54fed81b25e5 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 22 May 2023 18:59:55 +0100 Subject: [PATCH 06/16] Support type aliases; more tests --- mypy/checkexpr.py | 21 +++++++-- test-data/unit/check-generics.test | 60 ++++++++++++++++++++++++++ test-data/unit/check-plugin-attrs.test | 2 +- test-data/unit/pythoneval.test | 13 +++++- 4 files changed, 91 insertions(+), 5 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 809b2d9d6f43..192c58ab2598 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1794,7 +1794,10 @@ def infer_function_type_arguments( self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context) if any( - a is None or isinstance(get_proper_type(a), UninhabitedType) for a in inferred_args + a is None + or isinstance(get_proper_type(a), UninhabitedType) + or set(get_type_vars(a)) & set(callee_type.variables) + for a in inferred_args ): poly_inferred_args = infer_function_type_arguments( callee_type, @@ -1815,11 +1818,17 @@ def infer_function_type_arguments( no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables} if not set(get_type_vars(poly_callee_type)) & no_vars: applied = apply_poly(poly_callee_type, yes_vars) - if applied is not None and poly_inferred_args != [None] * len( + if applied is not None and poly_inferred_args != [UninhabitedType()] * len( poly_inferred_args ): freeze_all_type_vars(applied) return applied + inferred_args = [ + expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables}) + if a is not None + else None + for a in inferred_args + ] else: # In dynamically typed functions use implicit 'Any' types for # type variables. @@ -5369,7 +5378,13 @@ def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: raise PolyTranslationError() def visit_type_alias_type(self, t: TypeAliasType) -> Type: - return t.copy_modified(args=[a.accept(self) for a in t.args]) + if not t.args: + return t.copy_modified() + if not t.is_recursive: + return get_proper_type(t).accept(self) + # We can't handle polymorphic application for recursive generic aliases + # without risking an infinite recursion, just give up for now. + raise PolyTranslationError() class ArgInferSecondPassQuery(types.BoolTypeQuery): diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 877aeec83473..1cdbd0a95be1 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2751,6 +2751,21 @@ reveal_type(foo(id)) # N: Revealed type is "builtins.list[builtins.int]" reveal_type(bar(id)) # N: Revealed type is "builtins.list[builtins.int]" [builtins fixtures/list.pyi] +[case testInferenceAgainstGenericCallableNoLeak] +from typing import TypeVar, Callable + +T = TypeVar('T') + +def f(x: Callable[..., T]) -> T: + return x() + +def tpl(x: T) -> T: + return x + +# This is valid because of "..." +reveal_type(f(tpl)) # N: Revealed type is "Any" +[out] + [case testInferenceAgainstGenericCallableChain] from typing import TypeVar, Callable, List @@ -2777,6 +2792,34 @@ def id(x: U) -> U: reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> builtins.list[T`2]" [builtins fixtures/list.pyi] +[case testInferenceAgainstGenericCallableGenericReverse] +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def dec(f: Callable[[S], List[T]]) -> Callable[[S], T]: + ... 
+def id(x: U) -> U: + ... +reveal_type(dec(id)) # N: Revealed type is "def [T] (builtins.list[T`2]) -> T`2" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableGenericArg] +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def dec(f: Callable[[S], T]) -> Callable[[S], T]: + ... +def test(x: U) -> List[U]: + ... +reveal_type(dec(test)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]" +[builtins fixtures/list.pyi] + [case testInferenceAgainstGenericCallableGenericChain] from typing import TypeVar, Callable, List @@ -2828,3 +2871,20 @@ reveal_type(dec1(test1)) # N: Revealed type is "def () -> def [S] (S`2) -> S`2" # TODO: support this situation reveal_type(dec2(test2)) # N: Revealed type is "def (builtins.object) -> def (builtins.object) -> builtins.object" [builtins fixtures/paramspec.pyi] + +[case testInferenceAgainstGenericCallableGenericAlias] +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +A = Callable[[S], T] +B = Callable[[S], List[T]] + +def dec(f: A[S, T]) -> B[S, T]: + ... +def id(x: U) -> U: + ... +reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> builtins.list[T`2]" +[builtins fixtures/list.pyi] diff --git a/test-data/unit/check-plugin-attrs.test b/test-data/unit/check-plugin-attrs.test index ce1d670431c7..787ed43a85fd 100644 --- a/test-data/unit/check-plugin-attrs.test +++ b/test-data/unit/check-plugin-attrs.test @@ -1173,7 +1173,7 @@ def my_factory() -> int: return 7 @attr.s class A: - x: int = attr.ib(factory=list) # E: Incompatible types in assignment (expression has type "List[T]", variable has type "int") + x: int = attr.ib(factory=list) # E: Incompatible types in assignment (expression has type "List[]", variable has type "int") y: str = attr.ib(factory=my_factory) # E: Incompatible types in assignment (expression has type "int", variable has type "str") [builtins fixtures/list.pyi] diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 034c2190dd5e..f258d78bdb03 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -862,7 +862,7 @@ _program.py:6: error: Argument 1 to "defaultdict" has incompatible type "Type[Li _program.py:9: error: Invalid index type "str" for "defaultdict[int, str]"; expected type "int" _program.py:9: error: Incompatible types in assignment (expression has type "int", target has type "str") _program.py:19: error: Argument 1 to "tst" has incompatible type "defaultdict[str, List[]]"; expected "defaultdict[int, List[]]" -_program.py:23: error: Invalid index type "str" for "MyDDict[Dict[_KT, _VT]]"; expected type "int" +_program.py:23: error: Invalid index type "str" for "MyDDict[Dict[, ]]"; expected type "int" [case testNoSubcriptionOfStdlibCollections] # flags: --python-version 3.6 @@ -1986,3 +1986,14 @@ def good9(foo1: Foo[Concatenate[int, P]], foo2: Foo[[int, str, bytes]], *args: P [out] _testStrictEqualitywithParamSpec.py:11: error: Non-overlapping equality check (left operand type: "Foo[[int]]", right operand type: "Bar[[int]]") + +[case testGenericInferenceWithTuple] +from typing import TypeVar, Callable, Tuple + +T = TypeVar("T") + +def f(x: Callable[..., T]) -> T: + return x() + +x: Tuple[str, ...] 
= f(tuple) +[out] From 8eb04a9856e43b27bee245e44c18f03b4a64b4a5 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 22 May 2023 19:55:17 +0100 Subject: [PATCH 07/16] Support callback protocols; more tests --- mypy/checkexpr.py | 21 ++++++++++++++++++++- test-data/unit/check-generics.test | 16 ++++++++++++++++ test-data/unit/pythoneval.test | 26 ++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 192c58ab2598..7972b9091d0e 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -98,7 +98,13 @@ ) from mypy.semanal_enum import ENUM_BASES from mypy.state import state -from mypy.subtypes import is_equivalent, is_same_type, is_subtype, non_method_protocol_members +from mypy.subtypes import ( + find_member, + is_equivalent, + is_same_type, + is_subtype, + non_method_protocol_members, +) from mypy.traverser import has_await_expression from mypy.type_visitor import TypeTranslator from mypy.typeanal import ( @@ -5349,6 +5355,7 @@ class PolyTranslator(TypeTranslator): def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None: self.poly_tvars = set(poly_tvars) self.bound_tvars: set[TypeVarLikeType] = set() + self.seen_aliases: set[TypeInfo] = set() def visit_callable_type(self, t: CallableType) -> Type: found_vars = set() @@ -5386,6 +5393,18 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type: # without risking an infinite recursion, just give up for now. raise PolyTranslationError() + def visit_instance(self, t: Instance) -> Type: + # There is the same problem with callback protocols as with aliases + # (callback protocols are essentially more flexible aliases to callables) + if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]: + if t.type in self.seen_aliases: + raise PolyTranslationError() + self.seen_aliases.add(t.type) + call = find_member("__call__", t, t, is_operator=True) + assert call is not None + return call.accept(self) + return super().visit_instance(t) + class ArgInferSecondPassQuery(types.BoolTypeQuery): """Query whether an argument type should be inferred in the second pass. diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 1cdbd0a95be1..1f9fbddf8203 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2888,3 +2888,19 @@ def id(x: U) -> U: ... reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> builtins.list[T`2]" [builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableGenericProtocol] +# flags: --strict-optional +from typing import TypeVar, Protocol, Generic, Optional + +_T = TypeVar('_T') + +class _F(Protocol[_T]): + def __call__(self, __x: _T) -> _T: ... + +def lift(f: _F[_T]) -> _F[Optional[_T]]: ... +def g(x: _T) -> _T: + return x + +reveal_type(lift(g)) # N: Revealed type is "def [_T] (Union[_T`1, None]) -> Union[_T`1, None]" +[builtins fixtures/list.pyi] diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index f258d78bdb03..6cf06b180a88 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1997,3 +1997,29 @@ def f(x: Callable[..., T]) -> T: x: Tuple[str, ...] 
= f(tuple) [out] + +[case testGenericInferenceWithDataclass] +from typing import Any, Collection, List +from dataclasses import dataclass, field + +class Foo: + pass + +@dataclass +class A: + items: Collection[Foo] = field(default_factory=list) +[out] + +[case testGenericInferenceWithItertools] +from typing import TypeVar, Tuple +from itertools import groupby +K = TypeVar("K") +V = TypeVar("V") + +def fst(kv: Tuple[K, V]) -> K: + k, v = kv + return k + +pairs = [(len(s), s) for s in ["one", "two", "three"]] +grouped = groupby(pairs, key=fst) +[out] From ec8b695dbe656bca6929be3c514567f9233edb0f Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 22 May 2023 22:23:53 +0100 Subject: [PATCH 08/16] Add some docstring/comments --- mypy/checkexpr.py | 25 ++++++++++++++++++++++++- mypy/constraints.py | 6 ++++++ mypy/graph_utils.py | 3 +++ mypy/solve.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 7972b9091d0e..2e5022546a4e 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1805,6 +1805,10 @@ def infer_function_type_arguments( or set(get_type_vars(a)) & set(callee_type.variables) for a in inferred_args ): + # If the regular two-phase inference didn't work, try inferring type + # variables while allowing for polymorphic solutions, i.e. for solutions + # potentially involving free variables. + # TODO: support the similar inference for return type context. poly_inferred_args = infer_function_type_arguments( callee_type, arg_types, @@ -1816,6 +1820,7 @@ def infer_function_type_arguments( ) for i, pa in enumerate(get_proper_types(poly_inferred_args)): if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa): + # Indicate that free variables should not be applied in the call below. poly_inferred_args[i] = None poly_callee_type = self.apply_generic_arguments( callee_type, poly_inferred_args, context @@ -1823,12 +1828,15 @@ def infer_function_type_arguments( yes_vars = poly_callee_type.variables no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables} if not set(get_type_vars(poly_callee_type)) & no_vars: + # Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can + # be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed. applied = apply_poly(poly_callee_type, yes_vars) if applied is not None and poly_inferred_args != [UninhabitedType()] * len( poly_inferred_args ): freeze_all_type_vars(applied) return applied + # If it didn't work, erase free variables as , to avoid confusing errors. inferred_args = [ expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables}) if a is not None @@ -5337,6 +5345,15 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> Optional[CallableType]: + """Make free type variables generic in the type if possible. + + This will analyze the type `tp` while trying to create valid bindings for + type variables `poly_tvars` while traversing the type. 
This follows the same rules + as we do during semantic analysis phase, examples: + * Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T + * Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T) + * List[T] -> None (not possible) + """ try: return tp.copy_modified( arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types], @@ -5352,8 +5369,13 @@ class PolyTranslationError(TypeError): class PolyTranslator(TypeTranslator): + """Make free type variables generic in the type if possible. + + See docstring for apply_poly() for details. + """ def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None: self.poly_tvars = set(poly_tvars) + # This is a simplified version of TypeVarScope used during semantic analysis. self.bound_tvars: set[TypeVarLikeType] = set() self.seen_aliases: set[TypeInfo] = set() @@ -5395,7 +5417,8 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type: def visit_instance(self, t: Instance) -> Type: # There is the same problem with callback protocols as with aliases - # (callback protocols are essentially more flexible aliases to callables) + # (callback protocols are essentially more flexible aliases to callables). + # Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T]. if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]: if t.type in self.seen_aliases: raise PolyTranslationError() diff --git a/mypy/constraints.py b/mypy/constraints.py index 00dd6045ff3c..62109e066559 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -892,6 +892,12 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # FIX verify argument counts # TODO: Erase template variables if it is generic? if cactual.variables and cactual.param_spec() is None: + # If template is generic, unify it with template. Note: this is + # not an ideal solution (which would be adding the generic variables + # to the constraint inference set), but is a good first approximation, + # and this will prevent leaking these variables in the solutions. + # Note: this may infer constraints like T <: S or T <: List[S] + # that contain variables in the target. unified = mypy.subtypes.unify_generic_callable( cactual, template, ignore_return=True ) diff --git a/mypy/graph_utils.py b/mypy/graph_utils.py index 769cf081e080..399301a6b0fd 100644 --- a/mypy/graph_utils.py +++ b/mypy/graph_utils.py @@ -1,3 +1,5 @@ +"""Helpers for manipulations with graphs.""" + from __future__ import annotations from typing import AbstractSet, Iterable, Iterator, TypeVar @@ -54,6 +56,7 @@ def dfs(v: T) -> Iterator[set[T]]: def prepare_sccs( sccs: list[set[T]], edges: dict[T, list[T]] ) -> dict[AbstractSet[T], set[AbstractSet[T]]]: + """Use original edges to organize SCCs in a graph by dependencies between them.""" sccsmap = {v: frozenset(scc) for scc in sccs for v in scc} data: dict[AbstractSet[T], set[AbstractSet[T]]] = {} for scc in sccs: diff --git a/mypy/solve.py b/mypy/solve.py index 50fbe2ce7648..4762de90707d 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -42,6 +42,8 @@ def solve_constraints( if not vars: return [] if allow_polymorphic: + # Constraints like T :> S and S <: T are semantically the same, but they are + # represented differently. Normalize the constraint list w.r.t this equivalence. constraints = normalize_constraints(constraints, vars) # Collect a list of constraints for each type variable. 
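
To make the normalization step commented on above concrete: mypy can record the same fact either as T :> S or as S <: T, and normalize_constraints() complements each variable-to-variable constraint with its mirrored form (and drops strict duplicates) so the non-linear solver only has to reason about one representation. Below is a minimal self-contained sketch of that idea only, not mypy's actual implementation: Constraint is reduced to a plain tuple and SUBTYPE_OF/SUPERTYPE_OF/neg_op are stand-ins for the real definitions in mypy/constraints.py.

    SUBTYPE_OF, SUPERTYPE_OF = 0, 1

    def neg_op(op: int) -> int:
        # Mirror image: "T <: S" states the same fact as "S :> T".
        return SUPERTYPE_OF if op == SUBTYPE_OF else SUBTYPE_OF

    def normalize(constraints: list[tuple[str, int, str]]) -> list[tuple[str, int, str]]:
        # Add the mirrored form of each var-to-var constraint, skipping exact duplicates.
        res = list(constraints)
        for var, op, target in constraints:
            mirrored = (target, neg_op(op), var)
            if mirrored not in res:
                res.append(mirrored)
        return res

    # ("T", SUPERTYPE_OF, "S") encodes T :> S; after normalization S <: T is recorded too.
    print(normalize([("T", SUPERTYPE_OF, "S")]))  # [('T', 1, 'S'), ('S', 0, 'T')]
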
@@ -74,8 +76,19 @@ def solve_constraints( def solve_non_linear( vars: list[TypeVarId], constraints: list[Constraint], cmap: dict[TypeVarId, list[Constraint]] ) -> dict[TypeVarId, Type | None]: + """Solve set of constraints that may include non-linear ones, like T <: List[S]. + + The whole algorithm consists of five steps: + * Propagate via linear constraints to get all possible constraints for each variable + * Find dependencies between type variables, group them in SCCs, and sor topologically + * Check all SCC are intrinsically linear, it is impossible to solve T <: List[T] + * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC) + * Solve constraints iteratively starting from leafs, updating targets after each step. + """ extra_constraints = [] for tvar in vars: + # TODO: support iteratively inferring secondary constraints like + # Sequence[T] <: S <: Sequence[U] => T <: U extra_constraints.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap)) extra_constraints.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap)) constraints += remove_dups(extra_constraints) @@ -100,13 +113,18 @@ def solve_non_linear( ): free_vars.append(next(tv for tv in scc)) + # Flatten the SCCs that are independent, we can solve them together, + # since we don't need to update any targets in between. batches = [] for batch in raw_batches: next_bc = [] for scc in batch: next_bc.extend(list(scc)) batches.append(next_bc) + solutions = solve_iteratively(batches, cmap, free_vars) + # We remove the solutions like T = T for free variables. This will indicate + # to the apply function, that they should not be touched. for tv in free_vars: if tv in solutions: del solutions[tv] @@ -119,6 +137,7 @@ def solve_iteratively( cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId], ) -> dict[TypeVarId, Type | None]: + """Solve constraints for type variables sequentially, updating targets after each step.""" solutions: dict[TypeVarId, Type | None] = {} for batch in batches: tmap = solve_once(batch, cmap, free_vars) @@ -136,6 +155,7 @@ def solve_iteratively( def solve_once( vars: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] ) -> dict[TypeVarId, Type | None]: + """Solve constraints by finding by using meets of upper bounds, and joins of lower bounds.""" res: dict[TypeVarId, Type | None] = {} # Solve each type variable separately. for tvar in vars: @@ -191,6 +211,12 @@ def solve_once( def normalize_constraints( constraints: list[Constraint], vars: list[TypeVarId] ) -> list[Constraint]: + """Normalize list of constraints (to simplify life for the non-linear solver). + + This includes two things currently: + * Complement T :> S by S <: T + * Remove strict duplicates + """ res = constraints.copy() for c in constraints: if isinstance(c.target, TypeVarType): @@ -201,6 +227,13 @@ def normalize_constraints( def propagate_constraints_for( var: TypeVarId, direction: int, cmap: dict[TypeVarId, list[Constraint]] ) -> list[Constraint]: + """Propagate via linear constraints to get additional constraints for `var`. + + For example if we have constraints: + [T <: int, S <: T, S :> str] + we can add two more + [S <: int, T :> str] + """ extra_constraints = [] seen = set() front = [var] @@ -229,6 +262,11 @@ def propagate_constraints_for( def compute_dependencies( cmap: dict[TypeVarId, list[Constraint]] ) -> dict[TypeVarId, list[TypeVarId]]: + """Compute dependencies between type variables induced by constraints. 
+ + If we have a constraint like T <: List[S], we say that T depends on S, since + we will need to solve for S first before we can solve for T. + """ res = {} vars = list(cmap.keys()) for tv in cmap: @@ -240,6 +278,10 @@ def compute_dependencies( def check_linear(scc: set[TypeVarId], cmap: dict[TypeVarId, list[Constraint]]) -> bool: + """Check there are only linear constraints between type variables in SCC. + + Linear are constraints like T <: S (while T <: F[S] are non-linear). + """ for tv in scc: if any( get_vars(c.target, list(scc)) and not isinstance(c.target, TypeVarType) @@ -250,4 +292,5 @@ def check_linear(scc: set[TypeVarId], cmap: dict[TypeVarId, list[Constraint]]) - def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]: + """Find type variables for which we are solving in a target type.""" return {tv.id for tv in get_type_vars(target)} & set(vars) From 4c41c67bfc17baf5cc9cfdd0c8a25a3de4850eb4 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 22 May 2023 23:48:50 +0100 Subject: [PATCH 09/16] Some tweaks --- mypy/checkexpr.py | 11 ++++++----- mypy/constraints.py | 4 ++-- mypy/solve.py | 10 ++++++---- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 0842fec77699..4d3c9781dc73 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5387,7 +5387,7 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> Optional[CallableType]: """Make free type variables generic in the type if possible. - This will analyze the type `tp` while trying to create valid bindings for + This will translate the type `tp` while trying to create valid bindings for type variables `poly_tvars` while traversing the type. This follows the same rules as we do during semantic analysis phase, examples: * Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T @@ -5413,6 +5413,7 @@ class PolyTranslator(TypeTranslator): See docstring for apply_poly() for details. """ + def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None: self.poly_tvars = set(poly_tvars) # This is a simplified version of TypeVarScope used during semantic analysis. @@ -5422,14 +5423,14 @@ def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None: def visit_callable_type(self, t: CallableType) -> Type: found_vars = set() for arg in t.arg_types: - found_vars |= set(get_type_vars(arg)) - found_vars &= self.poly_tvars + found_vars |= set(get_type_vars(arg)) & self.poly_tvars + found_vars -= self.bound_tvars self.bound_tvars |= found_vars result = super().visit_callable_type(t) self.bound_tvars -= found_vars - assert isinstance(result, ProperType) - assert isinstance(result, CallableType) + + assert isinstance(result, ProperType) and isinstance(result, CallableType) result.variables = list(result.variables) + list(found_vars) return result diff --git a/mypy/constraints.py b/mypy/constraints.py index 6af3f3f4fdf0..395cdbf2972a 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -888,9 +888,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # FIX verify argument counts # TODO: Erase template variables if it is generic? if cactual.variables and cactual.param_spec() is None: - # If template is generic, unify it with template. Note: this is + # If actual is generic, unify it with template. 
Note: this is # not an ideal solution (which would be adding the generic variables - # to the constraint inference set), but is a good first approximation, + # to the constraint inference set), but it's a good first approximation, # and this will prevent leaking these variables in the solutions. # Note: this may infer constraints like T <: S or T <: List[S] # that contain variables in the target. diff --git a/mypy/solve.py b/mypy/solve.py index 4762de90707d..837f36c714af 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -148,6 +148,8 @@ def solve_iteratively( c.target = expand_type( c.target, {k: v for (k, v) in tmap.items() if v is not None} ) + # TODO: support backtracking lower/upper bound choices + # (will require switching this function from iterative to recursive). solutions.update(tmap) return solutions @@ -213,10 +215,10 @@ def normalize_constraints( ) -> list[Constraint]: """Normalize list of constraints (to simplify life for the non-linear solver). - This includes two things currently: - * Complement T :> S by S <: T - * Remove strict duplicates - """ + This includes two things currently: + * Complement T :> S by S <: T + * Remove strict duplicates + """ res = constraints.copy() for c in constraints: if isinstance(c.target, TypeVarType): From 163720ca9bcf2248ee3ed8a642423e4ba9b1a789 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 26 May 2023 23:20:19 +0100 Subject: [PATCH 10/16] Minor fixes --- mypy/checkexpr.py | 2 +- mypy/semanal.py | 2 +- mypy/solve.py | 4 +++- mypy/typeanal.py | 13 ------------- mypy/types.py | 13 +++++++++++++ 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 4d3c9781dc73..d4e4081e8f53 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5404,7 +5404,7 @@ def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> Optio return None -class PolyTranslationError(TypeError): +class PolyTranslationError(Exception): pass diff --git a/mypy/semanal.py b/mypy/semanal.py index 648852fdecc8..732cacf801f2 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -234,7 +234,6 @@ fix_instance_types, has_any_from_unimported_type, no_subscript_builtin_alias, - remove_dups, type_constructors, ) from mypy.typeops import function_type, get_type_vars, try_getting_str_literals_from_type @@ -277,6 +276,7 @@ get_proper_type, get_proper_types, is_named_instance, + remove_dups, ) from mypy.types_utils import is_invalid_recursive_alias, store_argument_type from mypy.typevars import fill_typevars diff --git a/mypy/solve.py b/mypy/solve.py index 837f36c714af..05c42101b0a7 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -8,7 +8,6 @@ from mypy.join import join_types from mypy.meet import meet_types from mypy.subtypes import is_subtype -from mypy.typeanal import remove_dups from mypy.typeops import get_type_vars from mypy.types import ( AnyType, @@ -20,6 +19,7 @@ UninhabitedType, UnionType, get_proper_type, + remove_dups, ) from mypy.typestate import type_state @@ -169,6 +169,8 @@ def solve_once( # bounds based on constraints. Note that we assume that the constraint # targets do not have constraint references. for c in cmap.get(tvar, []): + # There may be multiple steps needed to solve all vars within a + # (linear) SCC. We ignore targets pointing to not yet solved vars. 
if get_vars(c.target, [v for v in vars if v not in free_vars]): continue if c.op == SUPERTYPE_OF: diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 95acb71b45d2..f6554abc0ab3 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -1848,19 +1848,6 @@ def set_any_tvars( return TypeAliasType(node, args, newline, newcolumn) -def remove_dups(tvars: list[T]) -> list[T]: - if len(tvars) <= 1: - return tvars - # Get unique elements in order of appearance - all_tvars: set[T] = set() - new_tvars: list[T] = [] - for t in tvars: - if t not in all_tvars: - new_tvars.append(t) - all_tvars.add(t) - return new_tvars - - def flatten_tvars(lists: list[list[T]]) -> list[T]: result: list[T] = [] for lst in lists: diff --git a/mypy/types.py b/mypy/types.py index 0e1374466341..31e210d8af4f 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -3436,6 +3436,19 @@ def callable_with_ellipsis(any_type: AnyType, ret_type: Type, fallback: Instance ) +def remove_dups(types: list[T]) -> list[T]: + if len(types) <= 1: + return types + # Get unique elements in order of appearance + all_types: set[T] = set() + new_types: list[T] = [] + for t in types: + if t not in all_types: + new_types.append(t) + all_types.add(t) + return new_types + + # This cyclic import is unfortunate, but to avoid it we would need to move away all uses # of get_proper_type() from types.py. Majority of them have been removed, but few remaining # are quite tricky to get rid of, but ultimately we want to do it at some point. From 0bc41b06ed8878e2dcfbed194c82f87b3551f098 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 5 Jun 2023 16:52:34 +0100 Subject: [PATCH 11/16] Fix bugs --- mypy/solve.py | 212 +++++++++++++++++++---------- test-data/unit/check-generics.test | 43 +++++- 2 files changed, 179 insertions(+), 76 deletions(-) diff --git a/mypy/solve.py b/mypy/solve.py index 05c42101b0a7..00fcd5e3d6d9 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Iterable + from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op from mypy.expandtype import expand_type from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort @@ -55,7 +57,13 @@ def solve_constraints( if allow_polymorphic: solutions = solve_non_linear(vars, constraints, cmap) else: - solutions = solve_iteratively([vars], cmap, vars) + solutions = {} + for tv, cs in cmap.items(): + if not cs: + continue + lowers = [c.target for c in cs if c.op == SUPERTYPE_OF] + uppers = [c.target for c in cs if c.op == SUBTYPE_OF] + solutions[tv] = solve_one(lowers, uppers, []) res: list[Type | None] = [] for v in vars: @@ -81,14 +89,12 @@ def solve_non_linear( The whole algorithm consists of five steps: * Propagate via linear constraints to get all possible constraints for each variable * Find dependencies between type variables, group them in SCCs, and sor topologically - * Check all SCC are intrinsically linear, it is impossible to solve T <: List[T] + * Check all SCC are intrinsically linear, we can't solve (express) T <: List[T] * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC) * Solve constraints iteratively starting from leafs, updating targets after each step. 
""" extra_constraints = [] for tvar in vars: - # TODO: support iteratively inferring secondary constraints like - # Sequence[T] <: S <: Sequence[U] => T <: U extra_constraints.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap)) extra_constraints.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap)) constraints += remove_dups(extra_constraints) @@ -111,6 +117,7 @@ def solve_non_linear( for tv in scc for c in cmap[tv] ): + # TODO: be careful about upper bounds (or values) when introducing free vars. free_vars.append(next(tv for tv in scc)) # Flatten the SCCs that are independent, we can solve them together, @@ -122,9 +129,13 @@ def solve_non_linear( next_bc.extend(list(scc)) batches.append(next_bc) - solutions = solve_iteratively(batches, cmap, free_vars) + solutions: dict[TypeVarId, Type | None] = {} + for flat_batch in batches: + solutions.update(solve_iteratively(flat_batch, cmap, free_vars)) # We remove the solutions like T = T for free variables. This will indicate # to the apply function, that they should not be touched. + # TODO: return list of free type variables explicitly, this logic is fragile + # (but if we do, we need to be careful everything works in incremental modes). for tv in free_vars: if tv in solutions: del solutions[tv] @@ -133,83 +144,97 @@ def solve_non_linear( def solve_iteratively( - batches: list[list[TypeVarId]], - cmap: dict[TypeVarId, list[Constraint]], - free_vars: list[TypeVarId], + batch: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] ) -> dict[TypeVarId, Type | None]: """Solve constraints for type variables sequentially, updating targets after each step.""" - solutions: dict[TypeVarId, Type | None] = {} - for batch in batches: - tmap = solve_once(batch, cmap, free_vars) - if not tmap: + solutions = {} + relevant_constraints = [] + for tv in batch: + relevant_constraints.extend(cmap.get(tv, [])) + lowers, uppers = transitive_closure(batch, relevant_constraints) + s_batch = set(batch) + not_allowed_vars = [v for v in batch if v not in free_vars] + while s_batch: + for tv in s_batch: + if any(not get_vars(l, not_allowed_vars) for l in lowers[tv]) or any( + not get_vars(u, not_allowed_vars) for u in uppers[tv] + ): + solvable_tv = tv + break + else: + break + # Solve each solvable type variable separately. + s_batch.remove(solvable_tv) + result = solve_one(lowers[solvable_tv], uppers[solvable_tv], not_allowed_vars) + solutions[solvable_tv] = result + if result is None: + # TODO: support backtracking lower/upper bound choices + # (will require switching this function from iterative to recursive). continue + # Update the (transitive) constraints if there is a solution. + subs = {solvable_tv: result} + lowers = {tv: {expand_type(l, subs) for l in lowers[tv]} for tv in lowers} + uppers = {tv: {expand_type(u, subs) for u in uppers[tv]} for tv in uppers} for v in cmap: for c in cmap[v]: - c.target = expand_type( - c.target, {k: v for (k, v) in tmap.items() if v is not None} - ) - # TODO: support backtracking lower/upper bound choices - # (will require switching this function from iterative to recursive). 
- solutions.update(tmap) + c.target = expand_type(c.target, subs) return solutions -def solve_once( - vars: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] -) -> dict[TypeVarId, Type | None]: +def solve_one( + lowers: Iterable[Type], uppers: Iterable[Type], not_allowed_vars: list[TypeVarId] +) -> Type | None: """Solve constraints by finding by using meets of upper bounds, and joins of lower bounds.""" - res: dict[TypeVarId, Type | None] = {} - # Solve each type variable separately. - for tvar in vars: - bottom: Type | None = None - top: Type | None = None - candidate: Type | None = None - - # Process each constraint separately, and calculate the lower and upper - # bounds based on constraints. Note that we assume that the constraint - # targets do not have constraint references. - for c in cmap.get(tvar, []): - # There may be multiple steps needed to solve all vars within a - # (linear) SCC. We ignore targets pointing to not yet solved vars. - if get_vars(c.target, [v for v in vars if v not in free_vars]): - continue - if c.op == SUPERTYPE_OF: - if bottom is None: - bottom = c.target - else: - if type_state.infer_unions: - # This deviates from the general mypy semantics because - # recursive types are union-heavy in 95% of cases. - bottom = UnionType.make_union([bottom, c.target]) - else: - bottom = join_types(bottom, c.target) - else: - if top is None: - top = c.target - else: - top = meet_types(top, c.target) - - p_top = get_proper_type(top) - p_bottom = get_proper_type(bottom) - if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType): - source_any = top if isinstance(p_top, AnyType) else bottom - assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType) - res[tvar] = AnyType(TypeOfAny.from_another_any, source_any=source_any) + bottom: Type | None = None + top: Type | None = None + candidate: Type | None = None + + # Process each bound separately, and calculate the lower and upper + # bounds based on constraints. Note that we assume that the constraint + # targets do not have constraint references. + for target in lowers: + # There may be multiple steps needed to solve all vars within a + # (linear) SCC. We ignore targets pointing to not yet solved vars. + if get_vars(target, not_allowed_vars): continue - elif bottom is None: - if top: - candidate = top + if bottom is None: + bottom = target + else: + if type_state.infer_unions: + # This deviates from the general mypy semantics because + # recursive types are union-heavy in 95% of cases. + bottom = UnionType.make_union([bottom, target]) else: - # No constraints for type variable - continue - elif top is None: - candidate = bottom - elif is_subtype(bottom, top): - candidate = bottom + bottom = join_types(bottom, target) + + for target in uppers: + # Same as above. 
+ if get_vars(target, not_allowed_vars): + continue + if top is None: + top = target else: - candidate = None - res[tvar] = candidate - return res + top = meet_types(top, target) + + p_top = get_proper_type(top) + p_bottom = get_proper_type(bottom) + if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType): + source_any = top if isinstance(p_top, AnyType) else bottom + assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType) + return AnyType(TypeOfAny.from_another_any, source_any=source_any) + elif bottom is None: + if top: + candidate = top + else: + # No constraints for type variable + return None + elif top is None: + candidate = bottom + elif is_subtype(bottom, top): + candidate = bottom + else: + candidate = None + return candidate def normalize_constraints( @@ -263,6 +288,53 @@ def propagate_constraints_for( return extra_constraints +def transitive_closure( + tvars: list[TypeVarId], constraints: list[Constraint] +) -> tuple[dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]]]: + """Find transitive closure for given constraints on type variables. + + Transitive closure gives maximal set of lower/upper bounds for each type variable, such + we cannot deduce any further bounds by chaining other existing bounds. + """ + # TODO: merge propagate_constraints_for() into this function. + # TODO: add secondary constraints here to make the algorithm complete. + uppers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} + lowers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} + graph: set[tuple[TypeVarId, TypeVarId]] = set() + + # Prime the closure with the initial trivial values. + for c in constraints: + if isinstance(c.target, TypeVarType) and c.target.id in tvars: + if c.op == SUBTYPE_OF: + graph.add((c.type_var, c.target.id)) + else: + graph.add((c.target.id, c.type_var)) + if c.op == SUBTYPE_OF: + uppers[c.type_var].add(c.target) + else: + lowers[c.type_var].add(c.target) + + # At this stage we know that constant bounds have been propagated already, so we + # only need to propagate linear constraints. + for c in constraints: + if isinstance(c.target, TypeVarType) and c.target.id in tvars: + if c.op == SUBTYPE_OF: + lower, upper = c.type_var, c.target.id + else: + lower, upper = c.target.id, c.type_var + extras = { + (l, u) for l in tvars for u in tvars if (l, lower) in graph and (upper, u) in graph + } + graph |= extras + for u in tvars: + if (upper, u) in graph: + lowers[u] |= lowers[lower] + for l in tvars: + if (l, lower) in graph: + uppers[l] |= uppers[upper] + return lowers, uppers + + def compute_dependencies( cmap: dict[TypeVarId, list[Constraint]] ) -> dict[TypeVarId, list[TypeVarId]]: diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 1f9fbddf8203..3ba9e664fce0 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2893,14 +2893,45 @@ reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> builtins.list[T`2] # flags: --strict-optional from typing import TypeVar, Protocol, Generic, Optional -_T = TypeVar('_T') +T = TypeVar('T') -class _F(Protocol[_T]): - def __call__(self, __x: _T) -> _T: ... +class F(Protocol[T]): + def __call__(self, __x: T) -> T: ... -def lift(f: _F[_T]) -> _F[Optional[_T]]: ... -def g(x: _T) -> _T: +def lift(f: F[T]) -> F[Optional[T]]: ... 
+def g(x: T) -> T: return x -reveal_type(lift(g)) # N: Revealed type is "def [_T] (Union[_T`1, None]) -> Union[_T`1, None]" +reveal_type(lift(g)) # N: Revealed type is "def [T] (Union[T`1, None]) -> Union[T`1, None]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericSplitOrder] +# flags: --strict-optional +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def dec(f: Callable[[T], S], g: Callable[[T], int]) -> Callable[[T], List[S]]: ... +def id(x: U) -> U: + ... + +reveal_type(dec(id, id)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]" [builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericSplitOrderGeneric] +# flags: --strict-optional +from typing import TypeVar, Callable, Tuple + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') +V = TypeVar('V') + +def dec(f: Callable[[T], S], g: Callable[[T], U]) -> Callable[[T], Tuple[S, U]]: ... +def id(x: V) -> V: + ... + +reveal_type(dec(id, id)) # N: Revealed type is "def [S] (S`2) -> Tuple[S`2, S`2]" +[builtins fixtures/tuple.pyi] From fefe27ec45e68ebcefbce0ac1b02dbe30dabc56d Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 5 Jun 2023 22:54:30 +0100 Subject: [PATCH 12/16] Make free variable choice stable --- mypy/solve.py | 2 +- test-data/unit/check-generics.test | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mypy/solve.py b/mypy/solve.py index 00fcd5e3d6d9..a141f7dfc920 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -118,7 +118,7 @@ def solve_non_linear( for c in cmap[tv] ): # TODO: be careful about upper bounds (or values) when introducing free vars. - free_vars.append(next(tv for tv in scc)) + free_vars.append(sorted(scc, key=lambda x: x.raw_id)[0]) # Flatten the SCCs that are independent, we can solve them together, # since we don't need to update any targets in between. diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 3ba9e664fce0..80bdf43f92a8 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2789,7 +2789,7 @@ def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]: ... def id(x: U) -> U: ... -reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> builtins.list[T`2]" +reveal_type(dec(id)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGenericReverse] @@ -2830,7 +2830,7 @@ U = TypeVar('U') def comb(f: Callable[[T], S], g: Callable[[S], U]) -> Callable[[T], U]: ... def id(x: U) -> U: ... -reveal_type(comb(id, id)) # N: Revealed type is "def [S] (S`2) -> S`2" +reveal_type(comb(id, id)) # N: Revealed type is "def [T] (T`1) -> T`1" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGenericNonLinear] @@ -2849,7 +2849,7 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: def id(__x: U) -> U: ... fs = [id, id, id] -reveal_type(mix(fs)) # N: Revealed type is "def [T] (T`4) -> builtins.list[T`4]" +reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]" reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`5) -> builtins.list[S`5]" [builtins fixtures/list.pyi] @@ -2867,7 +2867,7 @@ def dec2(f: Callable[[T, U], S]) -> Callable[[U], Callable[[T], S]]: ... def test1(x: V) -> V: ... def test2(x: V, y: V) -> V: ... 
-reveal_type(dec1(test1)) # N: Revealed type is "def () -> def [S] (S`2) -> S`2" +reveal_type(dec1(test1)) # N: Revealed type is "def () -> def [T] (T`1) -> T`1" # TODO: support this situation reveal_type(dec2(test2)) # N: Revealed type is "def (builtins.object) -> def (builtins.object) -> builtins.object" [builtins fixtures/paramspec.pyi] @@ -2886,7 +2886,7 @@ def dec(f: A[S, T]) -> B[S, T]: ... def id(x: U) -> U: ... -reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> builtins.list[T`2]" +reveal_type(dec(id)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGenericProtocol] @@ -2933,5 +2933,5 @@ def dec(f: Callable[[T], S], g: Callable[[T], U]) -> Callable[[T], Tuple[S, U]]: def id(x: V) -> V: ... -reveal_type(dec(id, id)) # N: Revealed type is "def [S] (S`2) -> Tuple[S`2, S`2]" +reveal_type(dec(id, id)) # N: Revealed type is "def [T] (T`1) -> Tuple[T`1, T`1]" [builtins fixtures/tuple.pyi] From 4aca3baf6cced2960bc287f7b5d70e7a146cb50d Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 6 Jun 2023 00:17:59 +0100 Subject: [PATCH 13/16] Special-case a corner case --- mypy/constraints.py | 14 +++++++++++++- test-data/unit/check-generics.test | 14 ++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 395cdbf2972a..da9113d9ba79 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -887,7 +887,19 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: if param_spec is None: # FIX verify argument counts # TODO: Erase template variables if it is generic? - if cactual.variables and cactual.param_spec() is None: + if ( + cactual.variables + and cactual.param_spec() is None + # Technically, the correct inferred type for application of + # Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic + # like U -> U, should be Callable[..., Any], but if U is a self type, we can + # allow it to leak, to be later bound to self. A bunch of existing code depends + # on this old behaviour. + and not ( + any(tv.id.raw_id == 0 for tv in cactual.variables) + and template.is_ellipsis_args + ) + ): # If actual is generic, unify it with template. Note: this is # not an ideal solution (which would be adding the generic variables # to the constraint inference set), but it's a good first approximation, diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 80bdf43f92a8..012eb86dc878 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2935,3 +2935,17 @@ def id(x: V) -> V: reveal_type(dec(id, id)) # N: Revealed type is "def [T] (T`1) -> Tuple[T`1, T`1]" [builtins fixtures/tuple.pyi] + +[case testInferenceAgainstGenericEllipsisSelfSpecialCase] +from typing import Self, Callable, TypeVar + +T = TypeVar("T") +S = TypeVar("S") +def dec(f: Callable[..., T]) -> Callable[..., T]: ... + +class C: + @dec + def test(self) -> Self: ... 
+ +c: C +reveal_type(c.test()) # N: Revealed type is "__main__.C" From 42cc4cff09550a1b136eff080e119e0453087590 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 6 Jun 2023 22:54:09 +0100 Subject: [PATCH 14/16] Make self-type special-casing wider --- mypy/constraints.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index da9113d9ba79..d4b663353ae2 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -890,15 +890,12 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: if ( cactual.variables and cactual.param_spec() is None - # Technically, the correct inferred type for application of + # Technically, the correct inferred type for application of e.g. # Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic - # like U -> U, should be Callable[..., Any], but if U is a self type, we can - # allow it to leak, to be later bound to self. A bunch of existing code depends - # on this old behaviour. - and not ( - any(tv.id.raw_id == 0 for tv in cactual.variables) - and template.is_ellipsis_args - ) + # like U -> U, should be Callable[..., Any], but if U is a self-type, we can + # allow it to leak, to be later bound to self. A bunch of existing code + # depends on this old behaviour. + and not any(tv.id.raw_id == 0 for tv in cactual.variables) ): # If actual is generic, unify it with template. Note: this is # not an ideal solution (which would be adding the generic variables From 47db8591a984c67e33f83438dae34ac24b27029b Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 8 Jun 2023 23:58:54 +0100 Subject: [PATCH 15/16] Temporary flag to hide new inference --- mypy/checkexpr.py | 3 ++- mypy/constraints.py | 3 ++- mypy/main.py | 5 +++++ mypy/options.py | 2 ++ mypy/typestate.py | 6 +++++- test-data/unit/check-generics.test | 18 ++++++++++++++---- .../unit/check-parameter-specification.test | 2 ++ test-data/unit/check-plugin-attrs.test | 1 + test-data/unit/pythoneval.test | 14 +++++++++----- 9 files changed, 42 insertions(+), 12 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 36d800830d84..c60acbb933a1 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -309,6 +309,7 @@ def __init__( # on whether current expression is a callee, to give better error messages # related to type context. self.is_callee = False + type_state.infer_polymorphic = self.chk.options.new_type_inference def reset(self) -> None: self.resolved_type = {} @@ -1801,7 +1802,7 @@ def infer_function_type_arguments( elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg): self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context) - if any( + if self.chk.options.new_type_inference and any( a is None or isinstance(get_proper_type(a), UninhabitedType) or set(get_type_vars(a)) & set(callee_type.variables) diff --git a/mypy/constraints.py b/mypy/constraints.py index d4b663353ae2..803b9819be6f 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -888,7 +888,8 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # FIX verify argument counts # TODO: Erase template variables if it is generic? if ( - cactual.variables + type_state.infer_polymorphic + and cactual.variables and cactual.param_spec() is None # Technically, the correct inferred type for application of e.g. 
# Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic diff --git a/mypy/main.py b/mypy/main.py index 81a0a045745b..b60c5b2a6bba 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -983,6 +983,11 @@ def add_invertible_flag( dest="custom_typing_module", help="Use a custom typing module", ) + internals_group.add_argument( + "--new-type-inference", + action="store_true", + help="Enable new experimental type inference algorithm", + ) internals_group.add_argument( "--disable-recursive-aliases", action="store_true", diff --git a/mypy/options.py b/mypy/options.py index 45591597ba69..628765f2a201 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -344,6 +344,8 @@ def __init__(self) -> None: # skip most errors after this many messages have been reported. # -1 means unlimited. self.many_errors_threshold = defaults.MANY_ERRORS_THRESHOLD + # Enable new experimental type inference algorithm. + self.new_type_inference = False # Disable recursive type aliases (currently experimental) self.disable_recursive_aliases = False # Deprecated reverse version of the above, do not use. diff --git a/mypy/typestate.py b/mypy/typestate.py index 9f65481e5e94..ff5933af5928 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -93,6 +93,9 @@ class TypeState: inferring: Final[list[tuple[Type, Type]]] # Whether to use joins or unions when solving constraints, see checkexpr.py for details. infer_unions: bool + # Whether to use new type inference algorithm that can infer polymorphic types. + # This is temporary and will be removed soon when new algorithm is more polished. + infer_polymorphic: bool # N.B: We do all of the accesses to these properties through # TypeState, instead of making these classmethods and accessing @@ -110,6 +113,7 @@ def __init__(self) -> None: self._assuming_proper = [] self.inferring = [] self.infer_unions = False + self.infer_polymorphic = False def is_assumed_subtype(self, left: Type, right: Type) -> bool: for l, r in reversed(self._assuming): @@ -311,7 +315,7 @@ def add_all_protocol_deps(self, deps: dict[str, set[str]]) -> None: def reset_global_state() -> None: """Reset most existing global state. - Currently most of it is in this module. Few exceptions are strict optional status and + Currently most of it is in this module. Few exceptions are strict optional status and functools.lru_cache. 
""" type_state.reset_all_subtype_caches() diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 012eb86dc878..826ff30431c2 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2735,6 +2735,7 @@ reveal_type(dict2) # N: Revealed type is "builtins.dict[Any, __main__.B]" [builtins fixtures/dict.pyi] [case testInferenceAgainstGenericCallable] +# flags: --new-type-inference from typing import TypeVar, Callable, List X = TypeVar('X') @@ -2752,6 +2753,7 @@ reveal_type(bar(id)) # N: Revealed type is "builtins.list[builtins.int]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableNoLeak] +# flags: --new-type-inference from typing import TypeVar, Callable T = TypeVar('T') @@ -2767,6 +2769,7 @@ reveal_type(f(tpl)) # N: Revealed type is "Any" [out] [case testInferenceAgainstGenericCallableChain] +# flags: --new-type-inference from typing import TypeVar, Callable, List X = TypeVar('X') @@ -2779,6 +2782,7 @@ reveal_type(chain(id, id)) # N: Revealed type is "def (builtins.int) -> builtin [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGeneric] +# flags: --new-type-inference from typing import TypeVar, Callable, List S = TypeVar('S') @@ -2793,6 +2797,7 @@ reveal_type(dec(id)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1] [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGenericReverse] +# flags: --new-type-inference from typing import TypeVar, Callable, List S = TypeVar('S') @@ -2807,6 +2812,7 @@ reveal_type(dec(id)) # N: Revealed type is "def [T] (builtins.list[T`2]) -> T`2 [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGenericArg] +# flags: --new-type-inference from typing import TypeVar, Callable, List S = TypeVar('S') @@ -2821,6 +2827,7 @@ reveal_type(dec(test)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S` [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGenericChain] +# flags: --new-type-inference from typing import TypeVar, Callable, List S = TypeVar('S') @@ -2834,6 +2841,7 @@ reveal_type(comb(id, id)) # N: Revealed type is "def [T] (T`1) -> T`1" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGenericNonLinear] +# flags: --new-type-inference from typing import TypeVar, Callable, List S = TypeVar('S') @@ -2854,6 +2862,7 @@ reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`5) -> builtins [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCurry] +# flags: --new-type-inference from typing import Callable, List, TypeVar S = TypeVar("S") @@ -2873,6 +2882,7 @@ reveal_type(dec2(test2)) # N: Revealed type is "def (builtins.object) -> def (b [builtins fixtures/paramspec.pyi] [case testInferenceAgainstGenericCallableGenericAlias] +# flags: --new-type-inference from typing import TypeVar, Callable, List S = TypeVar('S') @@ -2890,7 +2900,7 @@ reveal_type(dec(id)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1] [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGenericProtocol] -# flags: --strict-optional +# flags: --strict-optional --new-type-inference from typing import TypeVar, Protocol, Generic, Optional T = TypeVar('T') @@ -2906,7 +2916,7 @@ reveal_type(lift(g)) # N: Revealed type is "def [T] (Union[T`1, None]) -> Union [builtins fixtures/list.pyi] [case testInferenceAgainstGenericSplitOrder] -# flags: --strict-optional +# flags: --strict-optional --new-type-inference from typing import TypeVar, Callable, List S = TypeVar('S') @@ 
-2921,7 +2931,7 @@ reveal_type(dec(id, id)) # N: Revealed type is "def (builtins.int) -> builtins. [builtins fixtures/list.pyi] [case testInferenceAgainstGenericSplitOrderGeneric] -# flags: --strict-optional +# flags: --strict-optional --new-type-inference from typing import TypeVar, Callable, Tuple S = TypeVar('S') @@ -2937,10 +2947,10 @@ reveal_type(dec(id, id)) # N: Revealed type is "def [T] (T`1) -> Tuple[T`1, T`1 [builtins fixtures/tuple.pyi] [case testInferenceAgainstGenericEllipsisSelfSpecialCase] +# flags: --new-type-inference from typing import Self, Callable, TypeVar T = TypeVar("T") -S = TypeVar("S") def dec(f: Callable[..., T]) -> Callable[..., T]: ... class C: diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index eb056b02596d..b712d7a3cb24 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1041,6 +1041,7 @@ reveal_type(jf(1)) # N: Revealed type is "None" [builtins fixtures/paramspec.pyi] [case testGenericsInInferredParamspecReturn] +# flags: --new-type-inference from typing import Callable, TypeVar, Generic from typing_extensions import ParamSpec @@ -1543,6 +1544,7 @@ def f(f: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ... [builtins fixtures/paramspec.pyi] [case testParamSpecDecoratorAppliedToGeneric] +# flags: --new-type-inference from typing import Callable, List, TypeVar from typing_extensions import ParamSpec diff --git a/test-data/unit/check-plugin-attrs.test b/test-data/unit/check-plugin-attrs.test index d496bd5dc30a..5b8c361906a8 100644 --- a/test-data/unit/check-plugin-attrs.test +++ b/test-data/unit/check-plugin-attrs.test @@ -1173,6 +1173,7 @@ class A: [builtins fixtures/bool.pyi] [case testAttrsFactoryBadReturn] +# flags: --new-type-inference import attr def my_factory() -> int: return 7 diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 1e572d33f094..c6a8fcc8a840 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -833,6 +833,7 @@ _program.py:3: error: Dict entry 1 has incompatible type "str": "str"; expected _program.py:5: error: "Dict[str, int]" has no attribute "xyz" [case testDefaultDict] +# flags: --new-type-inference import typing as t from collections import defaultdict @@ -858,11 +859,11 @@ class MyDDict(t.DefaultDict[int,T], t.Generic[T]): MyDDict(dict)['0'] MyDDict(dict)[0] [out] -_program.py:6: error: Argument 1 to "defaultdict" has incompatible type "Type[List[Any]]"; expected "Callable[[], str]" -_program.py:9: error: Invalid index type "str" for "defaultdict[int, str]"; expected type "int" -_program.py:9: error: Incompatible types in assignment (expression has type "int", target has type "str") -_program.py:19: error: Argument 1 to "tst" has incompatible type "defaultdict[str, List[]]"; expected "defaultdict[int, List[]]" -_program.py:23: error: Invalid index type "str" for "MyDDict[Dict[, ]]"; expected type "int" +_program.py:7: error: Argument 1 to "defaultdict" has incompatible type "Type[List[Any]]"; expected "Callable[[], str]" +_program.py:10: error: Invalid index type "str" for "defaultdict[int, str]"; expected type "int" +_program.py:10: error: Incompatible types in assignment (expression has type "int", target has type "str") +_program.py:20: error: Argument 1 to "tst" has incompatible type "defaultdict[str, List[]]"; expected "defaultdict[int, List[]]" +_program.py:24: error: Invalid index type "str" for "MyDDict[Dict[, 
]]"; expected type "int" [case testNoSubcriptionOfStdlibCollections] # flags: --python-version 3.6 @@ -2013,6 +2014,7 @@ def call(callback: Callable[[Unpack[Ts]], Any], *args: Unpack[Ts]) -> Any: ... [case testGenericInferenceWithTuple] +# flags: --new-type-inference from typing import TypeVar, Callable, Tuple T = TypeVar("T") @@ -2024,6 +2026,7 @@ x: Tuple[str, ...] = f(tuple) [out] [case testGenericInferenceWithDataclass] +# flags: --new-type-inference from typing import Any, Collection, List from dataclasses import dataclass, field @@ -2036,6 +2039,7 @@ class A: [out] [case testGenericInferenceWithItertools] +# flags: --new-type-inference from typing import TypeVar, Tuple from itertools import groupby K = TypeVar("K") From 66b4567aa36ae9d4429003e4106db1284d1178e4 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 18 Jun 2023 20:28:56 +0100 Subject: [PATCH 16/16] Address CR --- mypy/solve.py | 31 ++++++++++++++++++++++++++---- test-data/unit/check-generics.test | 21 ++++++++++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/mypy/solve.py b/mypy/solve.py index a141f7dfc920..6693d66f3479 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -88,7 +88,7 @@ def solve_non_linear( The whole algorithm consists of five steps: * Propagate via linear constraints to get all possible constraints for each variable - * Find dependencies between type variables, group them in SCCs, and sor topologically + * Find dependencies between type variables, group them in SCCs, and sort topologically * Check all SCC are intrinsically linear, we can't solve (express) T <: List[T] * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC) * Solve constraints iteratively starting from leafs, updating targets after each step. @@ -112,11 +112,18 @@ def solve_non_linear( leafs = raw_batches[0] free_vars = [] for scc in leafs: + # If all constrain targets in this SCC are type variables within the + # same SCC then the only meaningful solution we can express, is that + # each variable is equal to a new free variable. For example if we + # have T <: S, S <: U, we deduce: T = S = U = . if all( isinstance(c.target, TypeVarType) and c.target.id in vars for tv in scc for c in cmap[tv] ): + # For convenience with current type application machinery, we randomly + # choose one of the existing type variables in SCC and designate it as free + # instead of defining a new type variable as a common solution. # TODO: be careful about upper bounds (or values) when introducing free vars. free_vars.append(sorted(scc, key=lambda x: x.raw_id)[0]) @@ -146,7 +153,17 @@ def solve_non_linear( def solve_iteratively( batch: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] ) -> dict[TypeVarId, Type | None]: - """Solve constraints for type variables sequentially, updating targets after each step.""" + """Solve constraints sequentially, updating constraint targets after each step. + + We solve for type variables that appear in `batch`. If a constraint target is not constant + (i.e. constraint looks like T :> F[S, ...]), we substitute solutions found so far in + the target F[S, ...]. This way we can gradually solve for all variables in the batch taking + one solvable variable at a time (i.e. such a variable that has at least one constant bound). 
+
+    Importantly, variables in free_vars are considered constants, so for example if we have just
+    one initial constraint T <: List[S], we will have two SCCs {T} and {S}; we then
+    designate S as free, and therefore T = List[S] is a valid solution for T.
+    """
     solutions = {}
     relevant_constraints = []
     for tv in batch:
@@ -293,8 +310,14 @@ def transitive_closure(
 ) -> tuple[dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]]]:
     """Find transitive closure for given constraints on type variables.
 
-    Transitive closure gives maximal set of lower/upper bounds for each type variable, such
-    we cannot deduce any further bounds by chaining other existing bounds.
+    Transitive closure gives the maximal set of lower/upper bounds for each type variable,
+    such that we cannot deduce any further bounds by chaining other existing bounds.
+
+    For example, if we have initial constraints [T <: S, S <: U, U <: int], the transitive
+    closure is given by:
+      * {} <: T <: {S, U, int}
+      * {T} <: S <: {U, int}
+      * {T, S} <: U <: {int}
     """
     # TODO: merge propagate_constraints_for() into this function.
    # TODO: add secondary constraints here to make the algorithm complete.
diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test
index 826ff30431c2..b78fd21d4817 100644
--- a/test-data/unit/check-generics.test
+++ b/test-data/unit/check-generics.test
@@ -2734,6 +2734,9 @@ dict2 = {"a": C1(), **{x: C2() for x in dict1}}
 reveal_type(dict2) # N: Revealed type is "builtins.dict[Any, __main__.B]"
 [builtins fixtures/dict.pyi]
 
+-- Type inference for generic decorators applied to generic callables
+-- ------------------------------------------------------------------
+
 [case testInferenceAgainstGenericCallable]
 # flags: --new-type-inference
 from typing import TypeVar, Callable, List
@@ -2794,6 +2797,12 @@ def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]:
 def id(x: U) -> U:
     ...
 reveal_type(dec(id)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]"
+
+@dec
+def same(x: U) -> U:
+    ...
+reveal_type(same) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]"
+reveal_type(same(42)) # N: Revealed type is "builtins.list[builtins.int]"
 [builtins fixtures/list.pyi]
 
 [case testInferenceAgainstGenericCallableGenericReverse]
@@ -2809,6 +2818,12 @@ def dec(f: Callable[[S], List[T]]) -> Callable[[S], T]:
 def id(x: U) -> U:
     ...
 reveal_type(dec(id)) # N: Revealed type is "def [T] (builtins.list[T`2]) -> T`2"
+
+@dec
+def same(x: U) -> U:
+    ...
+reveal_type(same) # N: Revealed type is "def [T] (builtins.list[T`4]) -> T`4"
+reveal_type(same([42])) # N: Revealed type is "builtins.int"
 [builtins fixtures/list.pyi]
 
 [case testInferenceAgainstGenericCallableGenericArg]
@@ -2824,6 +2839,12 @@ def dec(f: Callable[[S], T]) -> Callable[[S], T]:
 def test(x: U) -> List[U]:
     ...
 reveal_type(dec(test)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]"
+
+@dec
+def single(x: U) -> List[U]:
+    ...
+reveal_type(single) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]"
+reveal_type(single(42)) # N: Revealed type is "builtins.list[builtins.int]"
 [builtins fixtures/list.pyi]
 
 [case testInferenceAgainstGenericCallableGenericChain]
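Editor's illustration (not part of the patch, and not mypy code): the transitive_closure() docstring above describes the closure only by example, so the short standalone Python sketch below models the same idea under simplified assumptions. Here type variables and constant bounds are plain strings rather than mypy's TypeVarId and Constraint objects, and the names closure/variables/constraints are illustrative only.

# Hypothetical, simplified model of the transitive closure described in the
# solve.py docstring above; not mypy's implementation.
def closure(variables, constraints):
    # constraints is a list of (lo, up) pairs, each meaning "lo <: up";
    # a name in `variables` is a type variable, anything else is a constant.
    lowers = {v: set() for v in variables}
    uppers = {v: set() for v in variables}
    changed = True
    while changed:
        changed = False
        for lo, up in constraints:
            if lo in variables:
                # lo gets up as an upper bound, plus everything already above up.
                new = {up} | (uppers[up] if up in variables else set())
                if not new <= uppers[lo]:
                    uppers[lo] |= new
                    changed = True
            if up in variables:
                # up gets lo as a lower bound, plus everything already below lo.
                new = {lo} | (lowers[lo] if lo in variables else set())
                if not new <= lowers[up]:
                    lowers[up] |= new
                    changed = True
    return lowers, uppers

lowers, uppers = closure({"T", "S", "U"}, [("T", "S"), ("S", "U"), ("U", "int")])
assert uppers["T"] == {"S", "U", "int"} and uppers["S"] == {"U", "int"}
assert lowers["U"] == {"T", "S"} and lowers["S"] == {"T"}

The fixpoint loop is what the docstring calls chaining existing bounds: repeatedly merging each variable's bounds with those of its neighbours until nothing changes reproduces the three bullets {} <: T <: {S, U, int}, {T} <: S <: {U, int} and {T, S} <: U <: {int}.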