
Commit 0873230

Foundations for non-linear solver and polymorphic application (#15287)
Fixes #1317 Fixes #5738 Fixes #12919 (also fixes a `FIX` comment that is more than 10 years old according to git blame)

Note: although this PR fixes most typical use cases for type inference against generic functions, it is intentionally incomplete, and it is deliberately scoped narrowly to limit its implications. This PR has essentially three components (better infer, better solve, better apply - all three are needed for this MVP to work):

* A "tiny" change to `constraints.py`: if the actual function is generic, we unify it with the template before inferring constraints. This prevents leaking the actual's generic type variables into the solutions (which makes no sense), but it also introduces a new kind of constraint, `T <: F[S]`, where the type variables we solve for appear in the target type. These are much harder to solve, but they are also a great opportunity to prepare for single-bin inference (if we switch to it in some form later). Note that unifying is not the best solution, but it is a good first approximation (see below for what the best solution is).
* A new, more sophisticated constraint solver in `solve.py`. The full algorithm is outlined in the docstring for `solve_non_linear()`. It looks like it should be able to solve arbitrary constraints that don't (indirectly) contain "F-bounded" things like `T <: list[T]`. In short, the idea is to compute the transitive closure, then organize constraints by topologically sorted SCCs (a toy sketch of this flow is included at the end of this description).
* Polymorphic type argument application in `checkexpr.py`. In cases where the solver identifies free variables (e.g. we have just one constraint `S <: list[T]`, so `T` is free, and the solution for `S` is `list[T]`), it will apply the solutions while creating new generic functions. For example, if we have a function `def [S, T] (fn: Callable[[S], T]) -> Callable[[S], T]` applied to a function `def [U] (x: U) -> U`, this will result in `def [T] (T) -> T` as the return.

I want to share some thoughts on the last ingredient, since it may seem mysterious, but it now looks to me like a very well-defined procedure. The key point is to think about generic functions as infinite intersections or infinite overloads. Reducing these infinite overloads/intersections to finite ones makes it easy to understand what is actually going on. For example, imagine we live in a world with just two types, `int` and `str`. Now we have two functions:
```python
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")

def dec(fn: Callable[[T], S]) -> Callable[[T], S]: ...
def id(x: U) -> U: ...
```
The first one can be seen as an overload over
```
((int) -> int) -> ((int) -> int)  # 1
((int) -> str) -> ((int) -> str)  # 2
((str) -> int) -> ((str) -> int)  # 3
((str) -> str) -> ((str) -> str)  # 4
```
and the second as an overload over
```
(int) -> int
(str) -> str
```
Now what happens when I apply `dec(id)`? We need to choose an overload that matches the argument (this is what we call type inference), but here is the trick: in this case two overloads of `dec` match the argument type. So (and btw I think we are missing this for real overloads) we construct a new overload that returns the intersection of the matching overloads `# 1` and `# 4`. Generalizing this intuition, inference is the selection of an (infinite) parametrized subset among the bigger parametrized set of intersecting types. The only question is whether the resulting infinite intersection is representable in our type system. For example, `forall T. dict[T, T]` can make sense but is not representable, while `forall T. (T) -> T` is a well-defined type. And finally, there is a very easy way to find out whether a type is representable or not: we are already doing this during semantic analysis. I use the same logic (which I used to view as ad hoc because of the lack of good syntax for callables) to bind type variables in the inferred type.
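To make this concrete, here is roughly the kind of code this PR targets (a hedged sketch, not a test from the PR; the exact `reveal_type` output and the `--new-type-inference` spelling of the new option are assumptions based on the `new_type_inference` option referenced in the diff below):

```python
# Hedged sketch: run with the experimental option this PR checks
# (options.new_type_inference, presumably exposed as --new-type-inference).
from typing import Callable, TypeVar

T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")

def dec(fn: Callable[[S], T]) -> Callable[[S], T]: ...
def ident(x: U) -> U: ...

# S is left free in the solution, so the result is re-generalized instead of
# collapsing to <nothing>:
reveal_type(dec(ident))  # expected: roughly "def [S] (S) -> S"
```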
OK, so here is the list of missing features, and some comments on them:

1. Instead of unifying the actual with the template, we should include the actual's variables in the variable set we solve for, as explained in #5738 (comment). Note, however, that this will only work together with the next item.
2. We need to (iteratively) infer secondary constraints after linear propagation, e.g. `Sequence[T] <: S <: Sequence[U] => T <: U`.
3. Support `ParamSpec` (and probably `TypeVarTuple`). Current support for applying callables with `ParamSpec` to generics is hacky, and kind of a dead end. Although `(Callable[P, T]) -> Callable[P, List[T]]` works when applied to `id`, even a slight variation like `(Callable[P, List[T]]) -> Callable[P, T]` fails. I think it needs to be reworked in the framework I propose (the tests I added are just to be sure I don't break existing code).
4. Support actual types that are generic in type variables with upper bounds or values (likely we just need to be careful when propagating constraints and choosing the free variable within an SCC).
5. Add backtracking for the upper/lower bound choice. In general, in the current "Hanoi Tower" inference scheme it is very hard to backtrack, but for this specific choice in the new solver it should be totally possible to switch from the lower to the upper bound on a previous step, if we found no solution (or `<nothing>`/`object`).
6. After we polish it, we can use the new solver in more situations, e.g. for return type context, and for unification during callable subtyping.
7. Long term we may want to allow instances to bind type variables, at least for things like `LRUCache[[x: T], T]`. Btw note that I force expansion of type aliases and callback protocols, since I can't transform e.g. `A = Callable[[T], T]` into a generic callable without getting the proper type.
8. We need to figure out a solution for scenarios where non-linear targets with free variables and constant targets mix without secondary constraints, like `T <: List[int], T <: List[S]`.

I am planning to address at least the majority of the above items, but I think we should move slowly, since in my experience type inference is a really fragile topic with hard-to-predict, far-reaching consequences. Please play with this PR if you want to and have time, and please suggest tests to add.
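As promised above, here is a toy sketch of the "transitive closure, then topologically sorted SCCs" flow from the solver bullet. This is not `solve.py`: types are plain strings, the join is fake, and only linear `T <: S` constraints plus constant lower bounds are handled; the SCC routine is adapted from the recipe this PR moves out of `mypy/build.py` (see the diff below).

```python
# Toy illustration of the solver shape, NOT mypy's solve.py.
from __future__ import annotations

from graphlib import TopologicalSorter


def strongly_connected_components(
    vertices: set[str], edges: dict[str, set[str]]
) -> list[set[str]]:
    identified: set[str] = set()
    stack: list[str] = []
    index: dict[str, int] = {}
    boundaries: list[int] = []
    result: list[set[str]] = []

    def dfs(v: str) -> None:
        index[v] = len(stack)
        stack.append(v)
        boundaries.append(index[v])
        for w in edges.get(v, set()):
            if w not in index:
                dfs(w)
            elif w not in identified:
                while index[w] < boundaries[-1]:
                    boundaries.pop()
        if boundaries[-1] == index[v]:
            boundaries.pop()
            scc = set(stack[index[v] :])
            del stack[index[v] :]
            identified.update(scc)
            result.append(scc)

    for v in vertices:
        if v not in index:
            dfs(v)
    return result


def join(types: set[str]) -> str:
    # Fake join: no lower bounds -> <nothing>, one bound -> itself, several -> object.
    if not types:
        return "<nothing>"
    return next(iter(types)) if len(types) == 1 else "object"


def solve(
    variables: set[str],
    var_edges: dict[str, set[str]],  # "T" <: "S"  =>  edge T -> S
    lower_bounds: dict[str, set[str]],  # e.g. "T" :> "str"
) -> dict[str, str]:
    """Solve every SCC of mutually constrained variables together, in topological order."""
    components = strongly_connected_components(variables, var_edges)
    scc_of = {v: frozenset(c) for c in components for v in c}
    # Condensation: which SCCs must be solved before a given SCC.
    preds: dict[frozenset[str], set[frozenset[str]]] = {frozenset(c): set() for c in components}
    for t, sups in var_edges.items():
        for s in sups:
            if scc_of[t] != scc_of[s]:
                preds[scc_of[s]].add(scc_of[t])

    bounds = {v: set(lower_bounds.get(v, set())) for v in variables}
    solution: dict[str, str] = {}
    for scc in TopologicalSorter(preds).static_order():  # predecessors come first
        best = join(set().union(*(bounds[v] for v in scc)))
        for v in scc:
            solution[v] = best
            if best != "<nothing>":
                # Propagate the choice as a lower bound of every supertype variable.
                for s in var_edges.get(v, set()):
                    bounds[s].add(best)
    return solution


# T <: S and S <: T form one SCC and must get the same answer; U is independent.
print(solve({"T", "S", "U"}, {"T": {"S"}, "S": {"T"}}, {"T": {"str"}, "U": {"int"}}))
# -> {'T': 'str', 'S': 'str', 'U': 'int'} (key order may vary)
```

Variables in one SCC necessarily get the same solution; everything else is decided once, in dependency order, which is the property the real solver relies on.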
1 parent 91b6740 commit 0873230

17 files changed, +998 −193 lines changed

mypy/build.py

+2 −104

```diff
@@ -31,14 +31,12 @@
     Callable,
     ClassVar,
     Dict,
-    Iterable,
     Iterator,
     Mapping,
     NamedTuple,
     NoReturn,
     Sequence,
     TextIO,
-    TypeVar,
 )
 from typing_extensions import Final, TypeAlias as _TypeAlias
 
@@ -47,6 +45,7 @@
 import mypy.semanal_main
 from mypy.checker import TypeChecker
 from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
+from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
 from mypy.indirection import TypeIndirectionVisitor
 from mypy.messages import MessageBuilder
 from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable, TypeInfo
@@ -3466,15 +3465,8 @@ def sorted_components(
     edges = {id: deps_filtered(graph, vertices, id, pri_max) for id in vertices}
     sccs = list(strongly_connected_components(vertices, edges))
     # Topsort.
-    sccsmap = {id: frozenset(scc) for scc in sccs for id in scc}
-    data: dict[AbstractSet[str], set[AbstractSet[str]]] = {}
-    for scc in sccs:
-        deps: set[AbstractSet[str]] = set()
-        for id in scc:
-            deps.update(sccsmap[x] for x in deps_filtered(graph, vertices, id, pri_max))
-        data[frozenset(scc)] = deps
     res = []
-    for ready in topsort(data):
+    for ready in topsort(prepare_sccs(sccs, edges)):
         # Sort the sets in ready by reversed smallest State.order. Examples:
         #
         # - If ready is [{x}, {y}], x.order == 1, y.order == 2, we get
@@ -3499,100 +3491,6 @@ def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: in
     ]
 
 
-def strongly_connected_components(
-    vertices: AbstractSet[str], edges: dict[str, list[str]]
-) -> Iterator[set[str]]:
-    """Compute Strongly Connected Components of a directed graph.
-
-    Args:
-      vertices: the labels for the vertices
-      edges: for each vertex, gives the target vertices of its outgoing edges
-
-    Returns:
-      An iterator yielding strongly connected components, each
-      represented as a set of vertices. Each input vertex will occur
-      exactly once; vertices not part of a SCC are returned as
-      singleton sets.
-
-    From https://code.activestate.com/recipes/578507/.
-    """
-    identified: set[str] = set()
-    stack: list[str] = []
-    index: dict[str, int] = {}
-    boundaries: list[int] = []
-
-    def dfs(v: str) -> Iterator[set[str]]:
-        index[v] = len(stack)
-        stack.append(v)
-        boundaries.append(index[v])
-
-        for w in edges[v]:
-            if w not in index:
-                yield from dfs(w)
-            elif w not in identified:
-                while index[w] < boundaries[-1]:
-                    boundaries.pop()
-
-        if boundaries[-1] == index[v]:
-            boundaries.pop()
-            scc = set(stack[index[v] :])
-            del stack[index[v] :]
-            identified.update(scc)
-            yield scc
-
-    for v in vertices:
-        if v not in index:
-            yield from dfs(v)
-
-
-T = TypeVar("T")
-
-
-def topsort(data: dict[T, set[T]]) -> Iterable[set[T]]:
-    """Topological sort.
-
-    Args:
-      data: A map from vertices to all vertices that it has an edge
-            connecting it to. NOTE: This data structure
-            is modified in place -- for normalization purposes,
-            self-dependencies are removed and entries representing
-            orphans are added.
-
-    Returns:
-      An iterator yielding sets of vertices that have an equivalent
-      ordering.
-
-    Example:
-      Suppose the input has the following structure:
-
-        {A: {B, C}, B: {D}, C: {D}}
-
-      This is normalized to:
-
-        {A: {B, C}, B: {D}, C: {D}, D: {}}
-
-      The algorithm will yield the following values:
-
-        {D}
-        {B, C}
-        {A}
-
-    From https://code.activestate.com/recipes/577413/.
-    """
-    # TODO: Use a faster algorithm?
-    for k, v in data.items():
-        v.discard(k)  # Ignore self dependencies.
-    for item in set.union(*data.values()) - set(data.keys()):
-        data[item] = set()
-    while True:
-        ready = {item for item, dep in data.items() if not dep}
-        if not ready:
-            break
-        yield ready
-        data = {item: (dep - ready) for item, dep in data.items() if item not in ready}
-    assert not data, f"A cyclic dependency exists amongst {data!r}"
-
-
 def missing_stubs_file(cache_dir: str) -> str:
     return os.path.join(cache_dir, "missing_stubs")
 
```
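The SCC/topsort helpers deleted above move to the new `mypy/graph_utils.py` (see the added import). That file is not part of this excerpt; judging from the inline code removed from `sorted_components()`, `prepare_sccs` presumably does roughly the following (a reconstruction, not the actual module):

```python
# Reconstruction (assumption) of prepare_sccs, based on the inline code removed
# from sorted_components() above; the real version lives in mypy/graph_utils.py.
from typing import AbstractSet


def prepare_sccs(
    sccs: list[set[str]], edges: dict[str, list[str]]
) -> dict[AbstractSet[str], set[AbstractSet[str]]]:
    """Map each SCC to the set of SCCs its members depend on (the input for topsort)."""
    sccsmap = {v: frozenset(scc) for scc in sccs for v in scc}
    data: dict[AbstractSet[str], set[AbstractSet[str]]] = {}
    for scc in sccs:
        deps: set[AbstractSet[str]] = set()
        for v in scc:
            deps.update(sccsmap[x] for x in edges[v])
        data[frozenset(scc)] = deps
    return data
```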

mypy/checkexpr.py

+143 −2

```diff
@@ -12,7 +12,7 @@
 import mypy.errorcodes as codes
 from mypy import applytype, erasetype, join, message_registry, nodes, operators, types
 from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals
-from mypy.checkmember import analyze_member_access, type_object_type
+from mypy.checkmember import analyze_member_access, freeze_all_type_vars, type_object_type
 from mypy.checkstrformat import StringFormatterChecker
 from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
 from mypy.errors import ErrorWatcher, report_internal_error
@@ -98,8 +98,15 @@
 )
 from mypy.semanal_enum import ENUM_BASES
 from mypy.state import state
-from mypy.subtypes import is_equivalent, is_same_type, is_subtype, non_method_protocol_members
+from mypy.subtypes import (
+    find_member,
+    is_equivalent,
+    is_same_type,
+    is_subtype,
+    non_method_protocol_members,
+)
 from mypy.traverser import has_await_expression
+from mypy.type_visitor import TypeTranslator
 from mypy.typeanal import (
     check_for_explicit_any,
     has_any_from_unimported_type,
@@ -114,6 +121,7 @@
     false_only,
     fixup_partial_type,
     function_type,
+    get_type_vars,
     is_literal_type_like,
     make_simplified_union,
     simple_literal_type,
@@ -146,6 +154,7 @@
     TypedDictType,
     TypeOfAny,
     TypeType,
+    TypeVarLikeType,
     TypeVarTupleType,
     TypeVarType,
     UninhabitedType,
@@ -300,6 +309,7 @@ def __init__(
         # on whether current expression is a callee, to give better error messages
         # related to type context.
        self.is_callee = False
+        type_state.infer_polymorphic = self.chk.options.new_type_inference
 
     def reset(self) -> None:
         self.resolved_type = {}
@@ -1791,6 +1801,51 @@ def infer_function_type_arguments(
                     inferred_args[0] = self.named_type("builtins.str")
                 elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg):
                     self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context)
+
+            if self.chk.options.new_type_inference and any(
+                a is None
+                or isinstance(get_proper_type(a), UninhabitedType)
+                or set(get_type_vars(a)) & set(callee_type.variables)
+                for a in inferred_args
+            ):
+                # If the regular two-phase inference didn't work, try inferring type
+                # variables while allowing for polymorphic solutions, i.e. for solutions
+                # potentially involving free variables.
+                # TODO: support the similar inference for return type context.
+                poly_inferred_args = infer_function_type_arguments(
+                    callee_type,
+                    arg_types,
+                    arg_kinds,
+                    formal_to_actual,
+                    context=self.argument_infer_context(),
+                    strict=self.chk.in_checked_function(),
+                    allow_polymorphic=True,
+                )
+                for i, pa in enumerate(get_proper_types(poly_inferred_args)):
+                    if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa):
+                        # Indicate that free variables should not be applied in the call below.
+                        poly_inferred_args[i] = None
+                poly_callee_type = self.apply_generic_arguments(
+                    callee_type, poly_inferred_args, context
+                )
+                yes_vars = poly_callee_type.variables
+                no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables}
+                if not set(get_type_vars(poly_callee_type)) & no_vars:
+                    # Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can
+                    # be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed.
+                    applied = apply_poly(poly_callee_type, yes_vars)
+                    if applied is not None and poly_inferred_args != [UninhabitedType()] * len(
+                        poly_inferred_args
+                    ):
+                        freeze_all_type_vars(applied)
+                        return applied
+                # If it didn't work, erase free variables as <nothing>, to avoid confusing errors.
+                inferred_args = [
+                    expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables})
+                    if a is not None
+                    else None
+                    for a in inferred_args
+                ]
         else:
             # In dynamically typed functions use implicit 'Any' types for
             # type variables.
@@ -5393,6 +5448,92 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl
     return c.copy_modified(ret_type=new_ret_type)
 
 
+def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> Optional[CallableType]:
+    """Make free type variables generic in the type if possible.
+
+    This will translate the type `tp` while trying to create valid bindings for
+    type variables `poly_tvars` while traversing the type. This follows the same rules
+    as we do during semantic analysis phase, examples:
+      * Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
+      * Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
+      * List[T] -> None (not possible)
+    """
+    try:
+        return tp.copy_modified(
+            arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types],
+            ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)),
+            variables=[],
+        )
+    except PolyTranslationError:
+        return None
+
+
+class PolyTranslationError(Exception):
+    pass
+
+
+class PolyTranslator(TypeTranslator):
+    """Make free type variables generic in the type if possible.
+
+    See docstring for apply_poly() for details.
+    """
+
+    def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None:
+        self.poly_tvars = set(poly_tvars)
+        # This is a simplified version of TypeVarScope used during semantic analysis.
+        self.bound_tvars: set[TypeVarLikeType] = set()
+        self.seen_aliases: set[TypeInfo] = set()
+
+    def visit_callable_type(self, t: CallableType) -> Type:
+        found_vars = set()
+        for arg in t.arg_types:
+            found_vars |= set(get_type_vars(arg)) & self.poly_tvars
+
+        found_vars -= self.bound_tvars
+        self.bound_tvars |= found_vars
+        result = super().visit_callable_type(t)
+        self.bound_tvars -= found_vars
+
+        assert isinstance(result, ProperType) and isinstance(result, CallableType)
+        result.variables = list(result.variables) + list(found_vars)
+        return result
+
+    def visit_type_var(self, t: TypeVarType) -> Type:
+        if t in self.poly_tvars and t not in self.bound_tvars:
+            raise PolyTranslationError()
+        return super().visit_type_var(t)
+
+    def visit_param_spec(self, t: ParamSpecType) -> Type:
+        # TODO: Support polymorphic apply for ParamSpec.
+        raise PolyTranslationError()
+
+    def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
+        # TODO: Support polymorphic apply for TypeVarTuple.
+        raise PolyTranslationError()
+
+    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
+        if not t.args:
+            return t.copy_modified()
+        if not t.is_recursive:
+            return get_proper_type(t).accept(self)
+        # We can't handle polymorphic application for recursive generic aliases
+        # without risking an infinite recursion, just give up for now.
+        raise PolyTranslationError()
+
+    def visit_instance(self, t: Instance) -> Type:
+        # There is the same problem with callback protocols as with aliases
+        # (callback protocols are essentially more flexible aliases to callables).
+        # Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T].
+        if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
+            if t.type in self.seen_aliases:
+                raise PolyTranslationError()
+            self.seen_aliases.add(t.type)
+            call = find_member("__call__", t, t, is_operator=True)
+            assert call is not None
+            return call.accept(self)
+        return super().visit_instance(t)
+
+
 class ArgInferSecondPassQuery(types.BoolTypeQuery):
     """Query whether an argument type should be inferred in the second pass.
```
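As a usage-level illustration of what `apply_poly()`/`PolyTranslator` are meant to enable, here is a hedged example (not taken from the PR's test suite); it also exercises the forced expansion of a generic type alias mentioned in the commit message. The alias name and the exact `reveal_type` rendering are assumptions.

```python
# Hedged example; requires the new_type_inference option, and the expected
# reveal_type output is an expectation, not a quote from mypy's tests.
from typing import Callable, List, TypeVar

T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")

Wrapper = Callable[[T], List[T]]  # generic alias, force-expanded by PolyTranslator

def dec(fn: Callable[[S], T]) -> Wrapper[S]: ...
def singleton(x: U) -> List[U]: ...

# S has no concrete solution, so apply_poly() re-binds it and the decorated
# function stays generic:
reveal_type(dec(singleton))  # expected: roughly "def [S] (S) -> builtins.list[S]"
```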

mypy/constraints.py

+24 −1

```diff
@@ -886,7 +886,30 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
             param_spec = template.param_spec()
             if param_spec is None:
                 # FIX verify argument counts
-                # FIX what if one of the functions is generic
+                # TODO: Erase template variables if it is generic?
+                if (
+                    type_state.infer_polymorphic
+                    and cactual.variables
+                    and cactual.param_spec() is None
+                    # Technically, the correct inferred type for application of e.g.
+                    # Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic
+                    # like U -> U, should be Callable[..., Any], but if U is a self-type, we can
+                    # allow it to leak, to be later bound to self. A bunch of existing code
+                    # depends on this old behaviour.
+                    and not any(tv.id.raw_id == 0 for tv in cactual.variables)
+                ):
+                    # If actual is generic, unify it with template. Note: this is
+                    # not an ideal solution (which would be adding the generic variables
+                    # to the constraint inference set), but it's a good first approximation,
+                    # and this will prevent leaking these variables in the solutions.
+                    # Note: this may infer constraints like T <: S or T <: List[S]
+                    # that contain variables in the target.
+                    unified = mypy.subtypes.unify_generic_callable(
+                        cactual, template, ignore_return=True
+                    )
+                    if unified is not None:
+                        cactual = unified
+                        res.extend(infer_constraints(cactual, template, neg_op(self.direction)))
 
                 # We can't infer constraints from arguments if the template is Callable[..., T]
                 # (with literal '...').
```
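To connect this to the solver changes, here is a hedged sketch of what the unification step above produces for the `dec`-style example from the commit message (names and notation are illustrative, not mypy internals):

```python
# template (the formal type of dec's parameter):   def (S) -> T
# cactual  (the generic argument being passed):    def [U] (x: U) -> list[U]
#
# unify_generic_callable(cactual, template, ignore_return=True) substitutes
# U := S, turning cactual into the non-generic  def (S) -> list[S].
# infer_constraints(cactual, template, neg_op(self.direction)) then yields,
# roughly, the constraint
#
#     T :> list[S]
#
# i.e. a constraint whose target mentions another variable being solved for,
# exactly the non-linear shape the new solver in solve.py is built to handle.
```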
