Skip to content

Commit b47245a

Browse files
authored
Improve type checking of ParamSpec in calls (#11603)
Check that the correct ParamSpec and flavor are used in `*args` and `**kwargs`. Follow-up to #11594.
1 parent 83e9e3b commit b47245a

File tree

5 files changed

+58
-14
lines changed

5 files changed

+58
-14
lines changed

mypy/argmap.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mypy.maptype import map_instance_to_supertype
66
from mypy.types import (
7-
Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType, get_proper_type
7+
Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType, ParamSpecType, get_proper_type
88
)
99
from mypy import nodes
1010

@@ -191,6 +191,9 @@ def expand_actual_type(self,
191191
else:
192192
self.tuple_index += 1
193193
return actual_type.items[self.tuple_index - 1]
194+
elif isinstance(actual_type, ParamSpecType):
195+
# ParamSpec is valid in *args but it can't be unpacked.
196+
return actual_type
194197
else:
195198
return AnyType(TypeOfAny.from_error)
196199
elif actual_kind == nodes.ARG_STAR2:
@@ -215,6 +218,9 @@ def expand_actual_type(self,
215218
actual_type,
216219
self.context.mapping_type.type,
217220
).args[1]
221+
elif isinstance(actual_type, ParamSpecType):
222+
# ParamSpec is valid in **kwargs but it can't be unpacked.
223+
return actual_type
218224
else:
219225
return AnyType(TypeOfAny.from_error)
220226
else:

mypy/checkexpr.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -1032,24 +1032,24 @@ def check_callable_call(self,
10321032
callee = self.infer_function_type_arguments(
10331033
callee, args, arg_kinds, formal_to_actual, context)
10341034
if need_refresh:
1035-
# Argument kinds etc. may have changed; recalculate actual-to-formal map
1035+
# Argument kinds etc. may have changed due to
1036+
# ParamSpec variables being replaced with an arbitrary
1037+
# number of arguments; recalculate actual-to-formal map
10361038
formal_to_actual = map_actuals_to_formals(
10371039
arg_kinds, arg_names,
10381040
callee.arg_kinds, callee.arg_names,
10391041
lambda i: self.accept(args[i]))
10401042

10411043
param_spec = callee.param_spec()
10421044
if param_spec is not None and arg_kinds == [ARG_STAR, ARG_STAR2]:
1043-
arg1 = get_proper_type(self.accept(args[0]))
1044-
arg2 = get_proper_type(self.accept(args[1]))
1045-
if (is_named_instance(arg1, 'builtins.tuple')
1046-
and is_named_instance(arg2, 'builtins.dict')):
1047-
assert isinstance(arg1, Instance)
1048-
assert isinstance(arg2, Instance)
1049-
if (isinstance(arg1.args[0], ParamSpecType)
1050-
and isinstance(arg2.args[1], ParamSpecType)):
1051-
# TODO: Check ParamSpec ids and flavors
1052-
return callee.ret_type, callee
1045+
arg1 = self.accept(args[0])
1046+
arg2 = self.accept(args[1])
1047+
if (isinstance(arg1, ParamSpecType)
1048+
and isinstance(arg2, ParamSpecType)
1049+
and arg1.flavor == ParamSpecFlavor.ARGS
1050+
and arg2.flavor == ParamSpecFlavor.KWARGS
1051+
and arg1.id == arg2.id == param_spec.id):
1052+
return callee.ret_type, callee
10531053

10541054
arg_types = self.infer_arg_types_in_context(
10551055
callee, args, arg_kinds, formal_to_actual)
@@ -4003,7 +4003,7 @@ def is_valid_var_arg(self, typ: Type) -> bool:
40034003
is_subtype(typ, self.chk.named_generic_type('typing.Iterable',
40044004
[AnyType(TypeOfAny.special_form)])) or
40054005
isinstance(typ, AnyType) or
4006-
(isinstance(typ, ParamSpecType) and typ.flavor == ParamSpecFlavor.ARGS))
4006+
isinstance(typ, ParamSpecType))
40074007

40084008
def is_valid_keyword_var_arg(self, typ: Type) -> bool:
40094009
"""Is a type valid as a **kwargs argument?"""
@@ -4012,7 +4012,7 @@ def is_valid_keyword_var_arg(self, typ: Type) -> bool:
40124012
[self.named_type('builtins.str'), AnyType(TypeOfAny.special_form)])) or
40134013
is_subtype(typ, self.chk.named_generic_type('typing.Mapping',
40144014
[UninhabitedType(), UninhabitedType()])) or
4015-
(isinstance(typ, ParamSpecType) and typ.flavor == ParamSpecFlavor.KWARGS)
4015+
isinstance(typ, ParamSpecType)
40164016
)
40174017
if self.chk.options.python_version[0] < 3:
40184018
ret = ret or is_subtype(typ, self.chk.named_generic_type('typing.Mapping',

mypy/expandtype.py

+2
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
106106
# Return copy of instance with type erasure flag on.
107107
return Instance(inst.type, inst.args, line=inst.line,
108108
column=inst.column, erased=True)
109+
elif isinstance(repl, ParamSpecType):
110+
return repl.with_flavor(t.flavor)
109111
else:
110112
return repl
111113

mypy/types.py

+4
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,10 @@ def new_unification_variable(old: 'ParamSpecType') -> 'ParamSpecType':
491491
return ParamSpecType(old.name, old.fullname, new_id, old.flavor, old.upper_bound,
492492
line=old.line, column=old.column)
493493

494+
def with_flavor(self, flavor: int) -> 'ParamSpecType':
495+
return ParamSpecType(self.name, self.fullname, self.id, flavor,
496+
upper_bound=self.upper_bound)
497+
494498
def accept(self, visitor: 'TypeVisitor[T]') -> T:
495499
return visitor.visit_param_spec(self)
496500

test-data/unit/check-parameter-specification.test

+32
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,35 @@ reveal_type(register(lambda x: f(x), x=1)) # N: Revealed type is "def (x: Any)"
346346
register(lambda x: f(x)) # E: Missing positional argument "x" in call to "register"
347347
register(lambda x: f(x), y=1) # E: Unexpected keyword argument "y" for "register"
348348
[builtins fixtures/dict.pyi]
349+
350+
[case testParamSpecInvalidCalls]
351+
from typing import Callable, Generic
352+
from typing_extensions import ParamSpec
353+
354+
P = ParamSpec('P')
355+
P2 = ParamSpec('P2')
356+
357+
class C(Generic[P, P2]):
358+
def m1(self, *args: P.args, **kwargs: P.kwargs) -> None:
359+
self.m1(*args, **kwargs)
360+
self.m2(*args, **kwargs) # E: Argument 1 to "m2" of "C" has incompatible type "*P.args"; expected "P2.args" \
361+
# E: Argument 2 to "m2" of "C" has incompatible type "**P.kwargs"; expected "P2.kwargs"
362+
self.m1(*kwargs, **args) # E: Argument 1 to "m1" of "C" has incompatible type "*P.kwargs"; expected "P.args" \
363+
# E: Argument 2 to "m1" of "C" has incompatible type "**P.args"; expected "P.kwargs"
364+
self.m3(*args, **kwargs) # E: Argument 1 to "m3" of "C" has incompatible type "*P.args"; expected "int" \
365+
# E: Argument 2 to "m3" of "C" has incompatible type "**P.kwargs"; expected "int"
366+
self.m4(*args, **kwargs) # E: Argument 1 to "m4" of "C" has incompatible type "*P.args"; expected "int" \
367+
# E: Argument 2 to "m4" of "C" has incompatible type "**P.kwargs"; expected "int"
368+
369+
self.m1(*args, **args) # E: Argument 2 to "m1" of "C" has incompatible type "**P.args"; expected "P.kwargs"
370+
self.m1(*kwargs, **kwargs) # E: Argument 1 to "m1" of "C" has incompatible type "*P.kwargs"; expected "P.args"
371+
372+
def m2(self, *args: P2.args, **kwargs: P2.kwargs) -> None:
373+
pass
374+
375+
def m3(self, *args: int, **kwargs: int) -> None:
376+
pass
377+
378+
def m4(self, x: int) -> None:
379+
pass
380+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)