Commit 14418bc

ilevkivskyi and Ivan Levkivskyi authored
Polymorphic inference: support for parameter specifications and lambdas (#15837)
This is the third follow-up for #15287 (there will likely be just one more PR, for `TypeVarTuple`s; a few less important items I mentioned in the original PR I will leave for the more distant future).

In the end this PR turned out larger than I wanted. The problem is that `Concatenate` support for `ParamSpec` was quite broken, and this caused many of my tests to fail. So I decided to include some major cleanup in this PR (I tried splitting it into a separate PR, but that turned out to be tricky). Still, if one ignores the added tests, the line count is almost net zero.

The main problems that I encountered are:
* First, valid substitutions for a `ParamSpecType` were: another `ParamSpecType`, `Parameters`, and `CallableType` (and also `AnyType` and `UninhabitedType`, but those seem to be handled trivially). Having `CallableType` in this list caused various missed cases and bogus `get_proper_type()`s, and was generally counter-intuitive.
* Second (and probably the bigger issue), it is possible to represent `Concatenate` in two different forms: as a prefix for `ParamSpecType` (used mostly for instances), and as separate argument types (used mostly for callables). The problem is that some parts of the code implicitly relied on it being in one form or the other, while other code switched between the two uncontrollably.

I propose to fix this by introducing some simplifications and rules (some of which I enforce with asserts):
* The only valid non-trivial substitutions (and consequently upper/lower bounds in constraints) for `ParamSpecType` are `ParamSpecType` and `Parameters`.
* When a `ParamSpecType` appears in a callable it must have an empty `prefix`.
* `Parameters` cannot contain other `Parameters` (and ideally also no `ParamSpecType`s) among its argument types.
* For inference we bring `Concatenate` to a common representation (because both callables and instances may appear in the same expression). Using the `ParamSpecType` representation with `prefix` looks significantly simpler (especially in the solver).

Apart from this, the actual implementation of polymorphic inference is straightforward: I just handle the additional `ParamSpecType` cases (in addition to `TypeVarType`) for inference, for the solver, and for application. I also enabled polymorphic inference for lambda expressions, since they are handled by similar code paths.

Some minor comments:
* I fixed a couple of minor bugs uncovered by this PR (see e.g. the test case for an accidental `TypeVar` id clash).
* I switched a few tests to `--new-type-inference` because their error messages are slightly different, and so it is easier for me to test a global flip to `True` locally.
* I may tweak some of the "ground rules" if the `mypy_primer` output is particularly bad.

---------

Co-authored-by: Ivan Levkivskyi <[email protected]>
1 parent fda7a46 commit 14418bc

20 files changed: +639 -234 lines changed

mypy/applytype.py

+7 -4
@@ -9,7 +9,6 @@
     AnyType,
     CallableType,
     Instance,
-    Parameters,
     ParamSpecType,
     PartialType,
     TupleType,
@@ -112,9 +111,13 @@ def apply_generic_arguments(
     if param_spec is not None:
         nt = id_to_type.get(param_spec.id)
         if nt is not None:
-            nt = get_proper_type(nt)
-            if isinstance(nt, (CallableType, Parameters)):
-                callable = callable.expand_param_spec(nt)
+            # ParamSpec expansion is special-cased, so we need to always expand callable
+            # as a whole, not expanding arguments individually.
+            callable = expand_type(callable, id_to_type)
+            assert isinstance(callable, CallableType)
+            return callable.copy_modified(
+                variables=[tv for tv in tvars if tv.id not in id_to_type]
+            )

     # Apply arguments to argument types.
     var_arg = callable.var_arg()

mypy/checker.py

+10 -3
@@ -4280,12 +4280,14 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
                 return_type = self.return_types[-1]
             return_type = get_proper_type(return_type)

+            is_lambda = isinstance(self.scope.top_function(), LambdaExpr)
             if isinstance(return_type, UninhabitedType):
-                self.fail(message_registry.NO_RETURN_EXPECTED, s)
-                return
+                # Avoid extra error messages for failed inference in lambdas
+                if not is_lambda or not return_type.ambiguous:
+                    self.fail(message_registry.NO_RETURN_EXPECTED, s)
+                    return

             if s.expr:
-                is_lambda = isinstance(self.scope.top_function(), LambdaExpr)
                 declared_none_return = isinstance(return_type, NoneType)
                 declared_any_return = isinstance(return_type, AnyType)

@@ -7376,6 +7378,11 @@ def visit_erased_type(self, t: ErasedType) -> bool:
         # This can happen inside a lambda.
         return True

+    def visit_type_var(self, t: TypeVarType) -> bool:
+        # This is needed to prevent leaking into partial types during
+        # multi-step type inference.
+        return t.id.is_meta_var()
+

 class SetNothingToAny(TypeTranslator):
     """Replace all ambiguous <nothing> types with Any (to avoid spurious extra errors)."""

mypy/checkexpr.py

+109 -14
@@ -17,7 +17,12 @@
 from mypy.checkstrformat import StringFormatterChecker
 from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
 from mypy.errors import ErrorWatcher, report_internal_error
-from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars
+from mypy.expandtype import (
+    expand_type,
+    expand_type_by_instance,
+    freshen_all_functions_type_vars,
+    freshen_function_type_vars,
+)
 from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments
 from mypy.literals import literal
 from mypy.maptype import map_instance_to_supertype
@@ -122,6 +127,7 @@
     false_only,
     fixup_partial_type,
     function_type,
+    get_all_type_vars,
     get_type_vars,
     is_literal_type_like,
     make_simplified_union,
@@ -145,6 +151,7 @@
     LiteralValue,
     NoneType,
     Overloaded,
+    Parameters,
     ParamSpecFlavor,
     ParamSpecType,
     PartialType,
@@ -167,6 +174,7 @@
     get_proper_types,
     has_recursive_types,
     is_named_instance,
+    remove_dups,
     split_with_prefix_and_suffix,
 )
 from mypy.types_utils import (
@@ -1579,6 +1587,16 @@ def check_callable_call(
             lambda i: self.accept(args[i]),
         )

+        # This is tricky: return type may contain its own type variables, like in
+        # def [S] (S) -> def [T] (T) -> tuple[S, T], so we need to update their ids
+        # to avoid possible id clashes if this call itself appears in a generic
+        # function body.
+        ret_type = get_proper_type(callee.ret_type)
+        if isinstance(ret_type, CallableType) and ret_type.variables:
+            fresh_ret_type = freshen_all_functions_type_vars(callee.ret_type)
+            freeze_all_type_vars(fresh_ret_type)
+            callee = callee.copy_modified(ret_type=fresh_ret_type)
+
         if callee.is_generic():
             need_refresh = any(
                 isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
@@ -1597,7 +1615,7 @@ def check_callable_call(
                     lambda i: self.accept(args[i]),
                 )
             callee = self.infer_function_type_arguments(
-                callee, args, arg_kinds, formal_to_actual, context
+                callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context
             )
             if need_refresh:
                 formal_to_actual = map_actuals_to_formals(
@@ -1864,6 +1882,8 @@ def infer_function_type_arguments_using_context(
         # def identity(x: T) -> T: return x
         #
         # expects_literal(identity(3)) # Should type-check
+        # TODO: we may want to add similar exception if all arguments are lambdas, since
+        # in this case external context is almost everything we have.
         if not is_generic_instance(ctx) and not is_literal_type_like(ctx):
             return callable.copy_modified()
         args = infer_type_arguments(callable.variables, ret_type, erased_ctx)
@@ -1885,7 +1905,9 @@ def infer_function_type_arguments(
         callee_type: CallableType,
         args: list[Expression],
         arg_kinds: list[ArgKind],
+        arg_names: Sequence[str | None] | None,
         formal_to_actual: list[list[int]],
+        need_refresh: bool,
         context: Context,
     ) -> CallableType:
         """Infer the type arguments for a generic callee type.
@@ -1927,7 +1949,14 @@ def infer_function_type_arguments(
             if 2 in arg_pass_nums:
                 # Second pass of type inference.
                 (callee_type, inferred_args) = self.infer_function_type_arguments_pass2(
-                    callee_type, args, arg_kinds, formal_to_actual, inferred_args, context
+                    callee_type,
+                    args,
+                    arg_kinds,
+                    arg_names,
+                    formal_to_actual,
+                    inferred_args,
+                    need_refresh,
+                    context,
                 )

             if (
@@ -1953,6 +1982,17 @@ def infer_function_type_arguments(
                 or set(get_type_vars(a)) & set(callee_type.variables)
                 for a in inferred_args
             ):
+                if need_refresh:
+                    # Technically we need to refresh formal_to_actual after *each* inference pass,
+                    # since each pass can expand ParamSpec or TypeVarTuple. Although such situations
+                    # are very rare, not doing this can cause crashes.
+                    formal_to_actual = map_actuals_to_formals(
+                        arg_kinds,
+                        arg_names,
+                        callee_type.arg_kinds,
+                        callee_type.arg_names,
+                        lambda a: self.accept(args[a]),
+                    )
                 # If the regular two-phase inference didn't work, try inferring type
                 # variables while allowing for polymorphic solutions, i.e. for solutions
                 # potentially involving free variables.
@@ -2000,8 +2040,10 @@ def infer_function_type_arguments_pass2(
         callee_type: CallableType,
         args: list[Expression],
         arg_kinds: list[ArgKind],
+        arg_names: Sequence[str | None] | None,
         formal_to_actual: list[list[int]],
         old_inferred_args: Sequence[Type | None],
+        need_refresh: bool,
         context: Context,
     ) -> tuple[CallableType, list[Type | None]]:
         """Perform second pass of generic function type argument inference.
@@ -2023,6 +2065,14 @@ def infer_function_type_arguments_pass2(
             if isinstance(arg, (NoneType, UninhabitedType)) or has_erased_component(arg):
                 inferred_args[i] = None
         callee_type = self.apply_generic_arguments(callee_type, inferred_args, context)
+        if need_refresh:
+            formal_to_actual = map_actuals_to_formals(
+                arg_kinds,
+                arg_names,
+                callee_type.arg_kinds,
+                callee_type.arg_names,
+                lambda a: self.accept(args[a]),
+            )

         arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual)

@@ -4735,8 +4785,22 @@ def infer_lambda_type_using_context(
         # they must be considered as indeterminate. We use ErasedType since it
         # does not affect type inference results (it is for purposes like this
         # only).
-        callable_ctx = get_proper_type(replace_meta_vars(ctx, ErasedType()))
-        assert isinstance(callable_ctx, CallableType)
+        if self.chk.options.new_type_inference:
+            # With new type inference we can preserve argument types even if they
+            # are generic, since new inference algorithm can handle constraints
+            # like S <: T (we still erase return type since it's ultimately unknown).
+            extra_vars = []
+            for arg in ctx.arg_types:
+                meta_vars = [tv for tv in get_all_type_vars(arg) if tv.id.is_meta_var()]
+                extra_vars.extend([tv for tv in meta_vars if tv not in extra_vars])
+            callable_ctx = ctx.copy_modified(
+                ret_type=replace_meta_vars(ctx.ret_type, ErasedType()),
+                variables=list(ctx.variables) + extra_vars,
+            )
+        else:
+            erased_ctx = replace_meta_vars(ctx, ErasedType())
+            assert isinstance(erased_ctx, ProperType) and isinstance(erased_ctx, CallableType)
+            callable_ctx = erased_ctx

         # The callable_ctx may have a fallback of builtins.type if the context
         # is a constructor -- but this fallback doesn't make sense for lambdas.
@@ -5693,18 +5757,28 @@ def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None:
         self.bound_tvars: set[TypeVarLikeType] = set()
         self.seen_aliases: set[TypeInfo] = set()

-    def visit_callable_type(self, t: CallableType) -> Type:
-        found_vars = set()
+    def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
+        found_vars = []
         for arg in t.arg_types:
-            found_vars |= set(get_type_vars(arg)) & self.poly_tvars
+            for tv in get_all_type_vars(arg):
+                if isinstance(tv, ParamSpecType):
+                    normalized: TypeVarLikeType = tv.copy_modified(
+                        flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], [])
+                    )
+                else:
+                    normalized = tv
+                if normalized in self.poly_tvars and normalized not in self.bound_tvars:
+                    found_vars.append(normalized)
+        return remove_dups(found_vars)

-        found_vars -= self.bound_tvars
-        self.bound_tvars |= found_vars
+    def visit_callable_type(self, t: CallableType) -> Type:
+        found_vars = self.collect_vars(t)
+        self.bound_tvars |= set(found_vars)
         result = super().visit_callable_type(t)
-        self.bound_tvars -= found_vars
+        self.bound_tvars -= set(found_vars)

         assert isinstance(result, ProperType) and isinstance(result, CallableType)
-        result.variables = list(result.variables) + list(found_vars)
+        result.variables = list(result.variables) + found_vars
         return result

     def visit_type_var(self, t: TypeVarType) -> Type:
@@ -5713,8 +5787,9 @@ def visit_type_var(self, t: TypeVarType) -> Type:
         return super().visit_type_var(t)

     def visit_param_spec(self, t: ParamSpecType) -> Type:
-        # TODO: Support polymorphic apply for ParamSpec.
-        raise PolyTranslationError()
+        if t in self.poly_tvars and t not in self.bound_tvars:
+            raise PolyTranslationError()
+        return super().visit_param_spec(t)

     def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
         # TODO: Support polymorphic apply for TypeVarTuple.
@@ -5730,6 +5805,26 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type:
         raise PolyTranslationError()

     def visit_instance(self, t: Instance) -> Type:
+        if t.type.has_param_spec_type:
+            # We need this special-casing to preserve the possibility to store a
+            # generic function in an instance type. Things like
+            # forall T . Foo[[x: T], T]
+            # are not really expressible in current type system, but this looks like
+            # a useful feature, so let's keep it.
+            param_spec_index = next(
+                i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType)
+            )
+            p = get_proper_type(t.args[param_spec_index])
+            if isinstance(p, Parameters):
+                found_vars = self.collect_vars(p)
+                self.bound_tvars |= set(found_vars)
+                new_args = [a.accept(self) for a in t.args]
+                self.bound_tvars -= set(found_vars)
+
+                repl = new_args[param_spec_index]
+                assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
+                repl.variables = list(repl.variables) + list(found_vars)
+                return t.copy_modified(args=new_args)
         # There is the same problem with callback protocols as with aliases
         # (callback protocols are essentially more flexible aliases to callables).
         # Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T].
