diff --git a/misc/proper_plugin.py b/misc/proper_plugin.py index 25dfac131bf8..249ad983266b 100644 --- a/misc/proper_plugin.py +++ b/misc/proper_plugin.py @@ -66,6 +66,7 @@ def is_special_target(right: ProperType) -> bool: if right.type_object().fullname in ( 'mypy.types.UnboundType', 'mypy.types.TypeVarType', + 'mypy.types.ParamSpecType', 'mypy.types.RawExpressionType', 'mypy.types.EllipsisType', 'mypy.types.StarType', diff --git a/mypy/applytype.py b/mypy/applytype.py index 6a5b5bf85cc3..5b803a4aaa0b 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -5,7 +5,7 @@ from mypy.expandtype import expand_type from mypy.types import ( Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types, - TypeVarLikeType, ProperType, ParamSpecType + TypeVarLikeType, ProperType, ParamSpecType, get_proper_type ) from mypy.nodes import Context @@ -18,9 +18,8 @@ def get_target_type( context: Context, skip_unsatisfied: bool ) -> Optional[Type]: - # TODO(PEP612): fix for ParamSpecType if isinstance(tvar, ParamSpecType): - return None + return type assert isinstance(tvar, TypeVarType) values = get_proper_types(tvar.values) if values: @@ -90,6 +89,14 @@ def apply_generic_arguments( if target_type is not None: id_to_type[tvar.id] = target_type + param_spec = callable.param_spec() + if param_spec is not None: + nt = id_to_type.get(param_spec.id) + if nt is not None: + nt = get_proper_type(nt) + if isinstance(nt, CallableType): + callable = callable.expand_param_spec(nt) + # Apply arguments to argument types. arg_types = [expand_type(at, id_to_type) for at in callable.arg_types] diff --git a/mypy/checker.py b/mypy/checker.py index 763176a3e6ae..5b73091365e5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -37,7 +37,8 @@ UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, is_named_instance, union_items, TypeQuery, LiteralType, is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType, - get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType) + get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType +) from mypy.sametypes import is_same_type from mypy.messages import ( MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq, @@ -976,13 +977,15 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) ctx = typ self.fail(message_registry.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, ctx) if typ.arg_kinds[i] == nodes.ARG_STAR: - # builtins.tuple[T] is typing.Tuple[T, ...] - arg_type = self.named_generic_type('builtins.tuple', - [arg_type]) + if not isinstance(arg_type, ParamSpecType): + # builtins.tuple[T] is typing.Tuple[T, ...] + arg_type = self.named_generic_type('builtins.tuple', + [arg_type]) elif typ.arg_kinds[i] == nodes.ARG_STAR2: - arg_type = self.named_generic_type('builtins.dict', - [self.str_type(), - arg_type]) + if not isinstance(arg_type, ParamSpecType): + arg_type = self.named_generic_type('builtins.dict', + [self.str_type(), + arg_type]) item.arguments[i].variable.type = arg_type # Type check initialization expressions. @@ -1883,7 +1886,7 @@ def check_protocol_variance(self, defn: ClassDef) -> None: expected = CONTRAVARIANT else: expected = INVARIANT - if expected != tvar.variance: + if isinstance(tvar, TypeVarType) and expected != tvar.variance: self.msg.bad_proto_variance(tvar.variance, tvar.name, expected, defn) def check_multiple_inheritance(self, typ: TypeInfo) -> None: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e850744b5c71..987df610ecd0 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -18,7 +18,7 @@ Type, AnyType, CallableType, Overloaded, NoneType, TypeVarType, TupleType, TypedDictType, Instance, ErasedType, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue, - is_named_instance, FunctionLike, ParamSpecType, + is_named_instance, FunctionLike, ParamSpecType, ParamSpecFlavor, StarType, is_optional, remove_optional, is_generic_instance, get_proper_type, ProperType, get_proper_types, flatten_nested_unions ) @@ -1023,11 +1023,31 @@ def check_callable_call(self, lambda i: self.accept(args[i])) if callee.is_generic(): + need_refresh = any(isinstance(v, ParamSpecType) for v in callee.variables) callee = freshen_function_type_vars(callee) callee = self.infer_function_type_arguments_using_context( callee, context) callee = self.infer_function_type_arguments( callee, args, arg_kinds, formal_to_actual, context) + if need_refresh: + # Argument kinds etc. may have changed; recalculate actual-to-formal map + formal_to_actual = map_actuals_to_formals( + arg_kinds, arg_names, + callee.arg_kinds, callee.arg_names, + lambda i: self.accept(args[i])) + + param_spec = callee.param_spec() + if param_spec is not None and arg_kinds == [ARG_STAR, ARG_STAR2]: + arg1 = get_proper_type(self.accept(args[0])) + arg2 = get_proper_type(self.accept(args[1])) + if (is_named_instance(arg1, 'builtins.tuple') + and is_named_instance(arg2, 'builtins.dict')): + assert isinstance(arg1, Instance) + assert isinstance(arg2, Instance) + if (isinstance(arg1.args[0], ParamSpecType) + and isinstance(arg2.args[1], ParamSpecType)): + # TODO: Check ParamSpec ids and flavors + return callee.ret_type, callee arg_types = self.infer_arg_types_in_context( callee, args, arg_kinds, formal_to_actual) @@ -3979,7 +3999,8 @@ def is_valid_var_arg(self, typ: Type) -> bool: return (isinstance(typ, TupleType) or is_subtype(typ, self.chk.named_generic_type('typing.Iterable', [AnyType(TypeOfAny.special_form)])) or - isinstance(typ, AnyType)) + isinstance(typ, AnyType) or + (isinstance(typ, ParamSpecType) and typ.flavor == ParamSpecFlavor.ARGS)) def is_valid_keyword_var_arg(self, typ: Type) -> bool: """Is a type valid as a **kwargs argument?""" @@ -3987,7 +4008,9 @@ def is_valid_keyword_var_arg(self, typ: Type) -> bool: is_subtype(typ, self.chk.named_generic_type('typing.Mapping', [self.named_type('builtins.str'), AnyType(TypeOfAny.special_form)])) or is_subtype(typ, self.chk.named_generic_type('typing.Mapping', - [UninhabitedType(), UninhabitedType()]))) + [UninhabitedType(), UninhabitedType()])) or + (isinstance(typ, ParamSpecType) and typ.flavor == ParamSpecFlavor.KWARGS) + ) if self.chk.options.python_version[0] < 3: ret = ret or is_subtype(typ, self.chk.named_generic_type('typing.Mapping', [self.named_type('builtins.unicode'), AnyType(TypeOfAny.special_form)])) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 01c6afeb9cec..89c53afdadad 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -6,7 +6,7 @@ from mypy.types import ( Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarLikeType, Overloaded, TypeVarType, UnionType, PartialType, TypeOfAny, LiteralType, - DeletedType, NoneType, TypeType, has_type_vars, get_proper_type, ProperType + DeletedType, NoneType, TypeType, has_type_vars, get_proper_type, ProperType, ParamSpecType ) from mypy.nodes import ( TypeInfo, FuncBase, Var, FuncDef, SymbolNode, SymbolTable, Context, @@ -666,6 +666,9 @@ def f(self: S) -> T: ... selfarg = item.arg_types[0] if subtypes.is_subtype(dispatched_arg_type, erase_typevars(erase_to_bound(selfarg))): new_items.append(item) + elif isinstance(selfarg, ParamSpecType): + # TODO: This is not always right. What's the most reasonable thing to do here? + new_items.append(item) if not new_items: # Choose first item for the message (it may be not very helpful for overloads). msg.incompatible_self_argument(name, dispatched_arg_type, items[0], diff --git a/mypy/constraints.py b/mypy/constraints.py index d7c4e5c2524a..9c1bfb0eba53 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -7,7 +7,8 @@ CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType, - ProperType, get_proper_type, TypeAliasType, is_union_with_any + ProperType, ParamSpecType, get_proper_type, TypeAliasType, is_union_with_any, + callable_with_ellipsis ) from mypy.maptype import map_instance_to_supertype import mypy.subtypes @@ -398,6 +399,10 @@ def visit_type_var(self, template: TypeVarType) -> List[Constraint]: assert False, ("Unexpected TypeVarType in ConstraintBuilderVisitor" " (should have been handled in infer_constraints)") + def visit_param_spec(self, template: ParamSpecType) -> List[Constraint]: + # Can't infer ParamSpecs from component values (only via Callable[P, T]). + return [] + # Non-leaf types def visit_instance(self, template: Instance) -> List[Constraint]: @@ -438,14 +443,16 @@ def visit_instance(self, template: Instance) -> List[Constraint]: # N.B: We use zip instead of indexing because the lengths might have # mismatches during daemon reprocessing. for tvar, mapped_arg, instance_arg in zip(tvars, mapped.args, instance.args): - # The constraints for generic type parameters depend on variance. - # Include constraints from both directions if invariant. - if tvar.variance != CONTRAVARIANT: - res.extend(infer_constraints( - mapped_arg, instance_arg, self.direction)) - if tvar.variance != COVARIANT: - res.extend(infer_constraints( - mapped_arg, instance_arg, neg_op(self.direction))) + # TODO: ParamSpecType + if isinstance(tvar, TypeVarType): + # The constraints for generic type parameters depend on variance. + # Include constraints from both directions if invariant. + if tvar.variance != CONTRAVARIANT: + res.extend(infer_constraints( + mapped_arg, instance_arg, self.direction)) + if tvar.variance != COVARIANT: + res.extend(infer_constraints( + mapped_arg, instance_arg, neg_op(self.direction))) return res elif (self.direction == SUPERTYPE_OF and instance.type.has_base(template.type.fullname)): @@ -454,14 +461,16 @@ def visit_instance(self, template: Instance) -> List[Constraint]: # N.B: We use zip instead of indexing because the lengths might have # mismatches during daemon reprocessing. for tvar, mapped_arg, template_arg in zip(tvars, mapped.args, template.args): - # The constraints for generic type parameters depend on variance. - # Include constraints from both directions if invariant. - if tvar.variance != CONTRAVARIANT: - res.extend(infer_constraints( - template_arg, mapped_arg, self.direction)) - if tvar.variance != COVARIANT: - res.extend(infer_constraints( - template_arg, mapped_arg, neg_op(self.direction))) + # TODO: ParamSpecType + if isinstance(tvar, TypeVarType): + # The constraints for generic type parameters depend on variance. + # Include constraints from both directions if invariant. + if tvar.variance != CONTRAVARIANT: + res.extend(infer_constraints( + template_arg, mapped_arg, self.direction)) + if tvar.variance != COVARIANT: + res.extend(infer_constraints( + template_arg, mapped_arg, neg_op(self.direction))) return res if (template.type.is_protocol and self.direction == SUPERTYPE_OF and # We avoid infinite recursion for structural subtypes by checking @@ -536,32 +545,48 @@ def infer_constraints_from_protocol_members(self, def visit_callable_type(self, template: CallableType) -> List[Constraint]: if isinstance(self.actual, CallableType): - cactual = self.actual - # FIX verify argument counts - # FIX what if one of the functions is generic res: List[Constraint] = [] + cactual = self.actual + param_spec = template.param_spec() + if param_spec is None: + # FIX verify argument counts + # FIX what if one of the functions is generic + + # We can't infer constraints from arguments if the template is Callable[..., T] + # (with literal '...'). + if not template.is_ellipsis_args: + # The lengths should match, but don't crash (it will error elsewhere). + for t, a in zip(template.arg_types, cactual.arg_types): + # Negate direction due to function argument type contravariance. + res.extend(infer_constraints(t, a, neg_op(self.direction))) + else: + # TODO: Direction + # TODO: Deal with arguments that come before param spec ones? + res.append(Constraint(param_spec.id, + SUBTYPE_OF, + cactual.copy_modified(ret_type=NoneType()))) - # We can't infer constraints from arguments if the template is Callable[..., T] (with - # literal '...'). - if not template.is_ellipsis_args: - # The lengths should match, but don't crash (it will error elsewhere). - for t, a in zip(template.arg_types, cactual.arg_types): - # Negate direction due to function argument type contravariance. - res.extend(infer_constraints(t, a, neg_op(self.direction))) template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type if template.type_guard is not None: template_ret_type = template.type_guard if cactual.type_guard is not None: cactual_ret_type = cactual.type_guard + res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction)) return res elif isinstance(self.actual, AnyType): - # FIX what if generic - res = self.infer_against_any(template.arg_types, self.actual) + param_spec = template.param_spec() any_type = AnyType(TypeOfAny.from_another_any, source_any=self.actual) - res.extend(infer_constraints(template.ret_type, any_type, self.direction)) - return res + if param_spec is None: + # FIX what if generic + res = self.infer_against_any(template.arg_types, self.actual) + res.extend(infer_constraints(template.ret_type, any_type, self.direction)) + return res + else: + return [Constraint(param_spec.id, + SUBTYPE_OF, + callable_with_ellipsis(any_type, any_type, template.fallback))] elif isinstance(self.actual, Overloaded): return self.infer_against_overloaded(self.actual, template) elif isinstance(self.actual, TypeType): diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 7a56eceacf5f..8acebbd783d8 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -4,7 +4,7 @@ Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType, - get_proper_type, TypeAliasType + get_proper_type, TypeAliasType, ParamSpecType ) from mypy.nodes import ARG_STAR, ARG_STAR2 @@ -57,6 +57,9 @@ def visit_instance(self, t: Instance) -> ProperType: def visit_type_var(self, t: TypeVarType) -> ProperType: return AnyType(TypeOfAny.special_form) + def visit_param_spec(self, t: ParamSpecType) -> ProperType: + return AnyType(TypeOfAny.special_form) + def visit_callable_type(self, t: CallableType) -> ProperType: # We must preserve the fallback type for overload resolution to work. any_type = AnyType(TypeOfAny.special_form) @@ -125,6 +128,11 @@ def visit_type_var(self, t: TypeVarType) -> Type: return self.replacement return t + def visit_param_spec(self, t: ParamSpecType) -> Type: + if self.erase_id(t.id): + return self.replacement + return t + def visit_type_alias_type(self, t: TypeAliasType) -> Type: # Type alias target can't contain bound type variables, so # it is safe to just erase the arguments. diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 3bdd95c40cb0..45a645583e5d 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -5,7 +5,7 @@ NoneType, Overloaded, TupleType, TypedDictType, UnionType, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, FunctionLike, TypeVarType, LiteralType, get_proper_type, ProperType, - TypeAliasType, ParamSpecType + TypeAliasType, ParamSpecType, TypeVarLikeType ) @@ -42,10 +42,11 @@ def freshen_function_type_vars(callee: F) -> F: tvmap: Dict[TypeVarId, Type] = {} for v in callee.variables: # TODO(PEP612): fix for ParamSpecType - if isinstance(v, ParamSpecType): - continue - assert isinstance(v, TypeVarType) - tv = TypeVarType.new_unification_variable(v) + if isinstance(v, TypeVarType): + tv: TypeVarLikeType = TypeVarType.new_unification_variable(v) + else: + assert isinstance(v, ParamSpecType) + tv = ParamSpecType.new_unification_variable(v) tvs.append(tv) tvmap[v.id] = tv fresh = cast(CallableType, expand_type(callee, tvmap)).copy_modified(variables=tvs) @@ -98,7 +99,35 @@ def visit_type_var(self, t: TypeVarType) -> Type: else: return repl + def visit_param_spec(self, t: ParamSpecType) -> Type: + repl = get_proper_type(self.variables.get(t.id, t)) + if isinstance(repl, Instance): + inst = repl + # Return copy of instance with type erasure flag on. + return Instance(inst.type, inst.args, line=inst.line, + column=inst.column, erased=True) + else: + return repl + def visit_callable_type(self, t: CallableType) -> Type: + param_spec = t.param_spec() + if param_spec is not None: + repl = get_proper_type(self.variables.get(param_spec.id)) + # If a ParamSpec in a callable type is substituted with a + # callable type, we can't use normal substitution logic, + # since ParamSpec is actually split into two components + # *P.args and **P.kwargs in the original type. Instead, we + # must expand both of them with all the argument types, + # kinds and names in the replacement. The return type in + # the replacement is ignored. + if isinstance(repl, CallableType): + # Substitute *args: P.args, **kwargs: P.kwargs + t = t.expand_param_spec(repl) + # TODO: Substitute remaining arg types + return t.copy_modified(ret_type=t.ret_type.accept(self), + type_guard=(t.type_guard.accept(self) + if t.type_guard is not None else None)) + return t.copy_modified(arg_types=self.expand_types(t.arg_types), ret_type=t.ret_type.accept(self), type_guard=(t.type_guard.accept(self) diff --git a/mypy/fixup.py b/mypy/fixup.py index f43cde2e37c5..da54c40e733f 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -10,7 +10,7 @@ from mypy.types import ( CallableType, Instance, Overloaded, TupleType, TypedDictType, TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType, - TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny + TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny, ParamSpecType ) from mypy.visitor import NodeVisitor from mypy.lookup import lookup_fully_qualified @@ -121,9 +121,10 @@ def visit_decorator(self, d: Decorator) -> None: def visit_class_def(self, c: ClassDef) -> None: for v in c.type_vars: - for value in v.values: - value.accept(self.type_fixer) - v.upper_bound.accept(self.type_fixer) + if isinstance(v, TypeVarType): + for value in v.values: + value.accept(self.type_fixer) + v.upper_bound.accept(self.type_fixer) def visit_type_var_expr(self, tv: TypeVarExpr) -> None: for value in tv.values: @@ -247,6 +248,9 @@ def visit_type_var(self, tvt: TypeVarType) -> None: if tvt.upper_bound is not None: tvt.upper_bound.accept(self) + def visit_param_spec(self, p: ParamSpecType) -> None: + p.upper_bound.accept(self) + def visit_unbound_type(self, o: UnboundType) -> None: for a in o.args: a.accept(self) diff --git a/mypy/indirection.py b/mypy/indirection.py index 96992285c90f..238f46c8830f 100644 --- a/mypy/indirection.py +++ b/mypy/indirection.py @@ -64,6 +64,9 @@ def visit_deleted_type(self, t: types.DeletedType) -> Set[str]: def visit_type_var(self, t: types.TypeVarType) -> Set[str]: return self._visit(t.values) | self._visit(t.upper_bound) + def visit_param_spec(self, t: types.ParamSpecType) -> Set[str]: + return set() + def visit_instance(self, t: types.Instance) -> Set[str]: out = self._visit(t.args) if t.type: diff --git a/mypy/join.py b/mypy/join.py index 291a934e5943..6bae241a801e 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -7,7 +7,7 @@ Type, AnyType, NoneType, TypeVisitor, Instance, UnboundType, TypeVarType, CallableType, TupleType, TypedDictType, ErasedType, UnionType, FunctionLike, Overloaded, LiteralType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, get_proper_type, - ProperType, get_proper_types, TypeAliasType, PlaceholderType + ProperType, get_proper_types, TypeAliasType, PlaceholderType, ParamSpecType ) from mypy.maptype import map_instance_to_supertype from mypy.subtypes import ( @@ -46,22 +46,28 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: new_type = AnyType(TypeOfAny.from_another_any, ta_proper) elif isinstance(sa_proper, AnyType): new_type = AnyType(TypeOfAny.from_another_any, sa_proper) - elif type_var.variance == COVARIANT: - new_type = join_types(ta, sa, self) - if len(type_var.values) != 0 and new_type not in type_var.values: - self.seen_instances.pop() - return object_from_instance(t) - if not is_subtype(new_type, type_var.upper_bound): - self.seen_instances.pop() - return object_from_instance(t) - # TODO: contravariant case should use meet but pass seen instances as - # an argument to keep track of recursive checks. - elif type_var.variance in (INVARIANT, CONTRAVARIANT): + elif isinstance(type_var, TypeVarType): + if type_var.variance == COVARIANT: + new_type = join_types(ta, sa, self) + if len(type_var.values) != 0 and new_type not in type_var.values: + self.seen_instances.pop() + return object_from_instance(t) + if not is_subtype(new_type, type_var.upper_bound): + self.seen_instances.pop() + return object_from_instance(t) + # TODO: contravariant case should use meet but pass seen instances as + # an argument to keep track of recursive checks. + elif type_var.variance in (INVARIANT, CONTRAVARIANT): + if not is_equivalent(ta, sa): + self.seen_instances.pop() + return object_from_instance(t) + # If the types are different but equivalent, then an Any is involved + # so using a join in the contravariant case is also OK. + new_type = join_types(ta, sa, self) + else: + # ParamSpec type variables behave the same, independent of variance if not is_equivalent(ta, sa): - self.seen_instances.pop() - return object_from_instance(t) - # If the types are different but equivalent, then an Any is involved - # so using a join in the contravariant case is also OK. + return get_proper_type(type_var.upper_bound) new_type = join_types(ta, sa, self) assert new_type is not None args.append(new_type) @@ -245,6 +251,11 @@ def visit_type_var(self, t: TypeVarType) -> ProperType: else: return self.default(self.s) + def visit_param_spec(self, t: ParamSpecType) -> ProperType: + if self.s == t: + return t + return self.default(self.s) + def visit_instance(self, t: Instance) -> ProperType: if isinstance(self.s, Instance): if self.instance_joiner is None: @@ -445,6 +456,8 @@ def default(self, typ: Type) -> ProperType: return self.default(typ.fallback) elif isinstance(typ, TypeVarType): return self.default(typ.upper_bound) + elif isinstance(typ, ParamSpecType): + return self.default(typ.upper_bound) else: return AnyType(TypeOfAny.special_form) diff --git a/mypy/meet.py b/mypy/meet.py index f89c1fc7b16f..644b57afbcbe 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -5,7 +5,8 @@ Type, AnyType, TypeVisitor, UnboundType, NoneType, TypeVarType, Instance, CallableType, TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, LiteralType, - ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeGuardedType + ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeGuardedType, + ParamSpecType ) from mypy.subtypes import is_equivalent, is_subtype, is_callable_compatible, is_proper_subtype from mypy.erasetype import erase_type @@ -499,6 +500,12 @@ def visit_type_var(self, t: TypeVarType) -> ProperType: else: return self.default(self.s) + def visit_param_spec(self, t: ParamSpecType) -> ProperType: + if self.s == t: + return self.s + else: + return self.default(self.s) + def visit_instance(self, t: Instance) -> ProperType: if isinstance(self.s, Instance): si = self.s diff --git a/mypy/messages.py b/mypy/messages.py index e3b12f49d980..72f53e1cf84e 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -24,7 +24,7 @@ Type, CallableType, Instance, TypeVarType, TupleType, TypedDictType, LiteralType, UnionType, NoneType, AnyType, Overloaded, FunctionLike, DeletedType, TypeType, UninhabitedType, TypeOfAny, UnboundType, PartialType, get_proper_type, ProperType, - get_proper_types + ParamSpecType, get_proper_types ) from mypy.typetraverser import TypeTraverserVisitor from mypy.nodes import ( @@ -1693,6 +1693,8 @@ def format(typ: Type) -> str: elif isinstance(typ, TypeVarType): # This is similar to non-generic instance types. return typ.name + elif isinstance(typ, ParamSpecType): + return typ.name_with_suffix() elif isinstance(typ, TupleType): # Prefer the name of the fallback class (if not tuple), as it's more informative. if typ.partial_fallback.type.fullname != 'builtins.tuple': @@ -1760,6 +1762,9 @@ def format(typ: Type) -> str: return_type = format(func.ret_type) if func.is_ellipsis_args: return 'Callable[..., {}]'.format(return_type) + param_spec = func.param_spec() + if param_spec is not None: + return f'Callable[{param_spec.name}, {return_type}]' arg_strings = [] for arg_name, arg_type, arg_kind in zip( func.arg_names, func.arg_types, func.arg_kinds): diff --git a/mypy/nodes.py b/mypy/nodes.py index 0e20457a4fab..ac2dbf336634 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -963,7 +963,7 @@ class ClassDef(Statement): name: str # Name of the class without module prefix fullname: Bogus[str] # Fully qualified name of the class defs: "Block" - type_vars: List["mypy.types.TypeVarType"] + type_vars: List["mypy.types.TypeVarLikeType"] # Base class expressions (not semantically analyzed -- can be arbitrary expressions) base_type_exprs: List[Expression] # Special base classes like Generic[...] get moved here during semantic analysis @@ -978,7 +978,7 @@ class ClassDef(Statement): def __init__(self, name: str, defs: 'Block', - type_vars: Optional[List['mypy.types.TypeVarType']] = None, + type_vars: Optional[List['mypy.types.TypeVarLikeType']] = None, base_type_exprs: Optional[List[Expression]] = None, metaclass: Optional[Expression] = None, keywords: Optional[List[Tuple[str, Expression]]] = None) -> None: diff --git a/mypy/sametypes.py b/mypy/sametypes.py index 020bda775b59..33cd7f0606cf 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -4,7 +4,8 @@ Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType, UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType, Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType, - ProperType, get_proper_type, TypeAliasType) + ProperType, get_proper_type, TypeAliasType, ParamSpecType +) from mypy.typeops import tuple_fallback, make_simplified_union @@ -96,6 +97,11 @@ def visit_type_var(self, left: TypeVarType) -> bool: return (isinstance(self.right, TypeVarType) and left.id == self.right.id) + def visit_param_spec(self, left: ParamSpecType) -> bool: + # Ignore upper bound since it's derived from flavor. + return (isinstance(self.right, ParamSpecType) and + left.id == self.right.id and left.flavor == self.right.flavor) + def visit_callable_type(self, left: CallableType) -> bool: # FIX generics if isinstance(self.right, CallableType): diff --git a/mypy/semanal.py b/mypy/semanal.py index 3e83f0f95899..86183704f680 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -77,8 +77,7 @@ REVEAL_LOCALS, is_final_node, TypedDictExpr, type_aliases_source_versions, typing_extensions_aliases, EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement, AssignmentExpr, - ParamSpecExpr, EllipsisExpr, - FuncBase, implicit_module_attrs, + ParamSpecExpr, EllipsisExpr, TypeVarLikeExpr, FuncBase, implicit_module_attrs, ) from mypy.tvar_scope import TypeVarLikeScope from mypy.typevars import fill_typevars @@ -93,7 +92,7 @@ FunctionLike, UnboundType, TypeVarType, TupleType, UnionType, StarType, CallableType, Overloaded, Instance, Type, AnyType, LiteralType, LiteralValue, TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType, - get_proper_type, get_proper_types, TypeAliasType, + get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType ) from mypy.typeops import function_type from mypy.type_visitor import TypeQuery @@ -1129,7 +1128,8 @@ def analyze_class(self, defn: ClassDef) -> None: defn, bases, context=defn) for tvd in tvar_defs: - if any(has_placeholder(t) for t in [tvd.upper_bound] + tvd.values): + if (isinstance(tvd, TypeVarType) + and any(has_placeholder(t) for t in [tvd.upper_bound] + tvd.values)): # Some type variable bounds or values are not ready, we need # to re-analyze this class. self.defer() @@ -1291,7 +1291,7 @@ def clean_up_bases_and_infer_type_variables( defn: ClassDef, base_type_exprs: List[Expression], context: Context) -> Tuple[List[Expression], - List[TypeVarType], + List[TypeVarLikeType], bool]: """Remove extra base classes such as Generic and infer type vars. @@ -1352,13 +1352,10 @@ class Foo(Bar, Generic[T]): ... # grained incremental mode. defn.removed_base_type_exprs.append(defn.base_type_exprs[i]) del base_type_exprs[i] - tvar_defs: List[TypeVarType] = [] + tvar_defs: List[TypeVarLikeType] = [] for name, tvar_expr in declared_tvars: tvar_def = self.tvar_scope.bind_new(name, tvar_expr) - if isinstance(tvar_def, TypeVarType): - # TODO(PEP612): fix for ParamSpecType - # Error will be reported elsewhere: #11218 - tvar_defs.append(tvar_def) + tvar_defs.append(tvar_def) return base_type_exprs, tvar_defs, is_protocol def analyze_class_typevar_declaration( @@ -1395,13 +1392,18 @@ def analyze_class_typevar_declaration( return tvars, is_proto return None - def analyze_unbound_tvar(self, t: Type) -> Optional[Tuple[str, TypeVarExpr]]: + def analyze_unbound_tvar(self, t: Type) -> Optional[Tuple[str, TypeVarLikeExpr]]: if not isinstance(t, UnboundType): return None unbound = t sym = self.lookup_qualified(unbound.name, unbound) if sym and isinstance(sym.node, PlaceholderNode): self.record_incomplete_ref() + if sym and isinstance(sym.node, ParamSpecExpr): + if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname): + # It's bound by our type variable scope + return None + return unbound.name, sym.node if sym is None or not isinstance(sym.node, TypeVarExpr): return None elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname): diff --git a/mypy/semanal_typeargs.py b/mypy/semanal_typeargs.py index d86434e62a19..50e01145d60a 100644 --- a/mypy/semanal_typeargs.py +++ b/mypy/semanal_typeargs.py @@ -9,7 +9,8 @@ from mypy.nodes import TypeInfo, Context, MypyFile, FuncItem, ClassDef, Block, FakeInfo from mypy.types import ( - Type, Instance, TypeVarType, AnyType, get_proper_types, TypeAliasType, get_proper_type + Type, Instance, TypeVarType, AnyType, get_proper_types, TypeAliasType, ParamSpecType, + get_proper_type ) from mypy.mixedtraverser import MixedTraverserVisitor from mypy.subtypes import is_subtype @@ -70,23 +71,28 @@ def visit_instance(self, t: Instance) -> None: if isinstance(info, FakeInfo): return # https://github.com/python/mypy/issues/11079 for (i, arg), tvar in zip(enumerate(t.args), info.defn.type_vars): - if tvar.values: - if isinstance(arg, TypeVarType): - arg_values = arg.values - if not arg_values: - self.fail( - message_registry.INVALID_TYPEVAR_AS_TYPEARG.format( - arg.name, info.name), - t, code=codes.TYPE_VAR) - continue - else: - arg_values = [arg] - self.check_type_var_values(info, arg_values, tvar.name, tvar.values, i + 1, t) - if not is_subtype(arg, tvar.upper_bound): - self.fail( - message_registry.INVALID_TYPEVAR_ARG_BOUND.format( - format_type(arg), info.name, format_type(tvar.upper_bound)), - t, code=codes.TYPE_VAR) + if isinstance(tvar, TypeVarType): + if isinstance(arg, ParamSpecType): + # TODO: Better message + self.fail(f'Invalid location for ParamSpec "{arg.name}"', t) + continue + if tvar.values: + if isinstance(arg, TypeVarType): + arg_values = arg.values + if not arg_values: + self.fail( + message_registry.INVALID_TYPEVAR_AS_TYPEARG.format( + arg.name, info.name), + t, code=codes.TYPE_VAR) + continue + else: + arg_values = [arg] + self.check_type_var_values(info, arg_values, tvar.name, tvar.values, i + 1, t) + if not is_subtype(arg, tvar.upper_bound): + self.fail( + message_registry.INVALID_TYPEVAR_ARG_BOUND.format( + format_type(arg), info.name, format_type(tvar.upper_bound)), + t, code=codes.TYPE_VAR) super().visit_instance(t) def check_type_var_values(self, type: TypeInfo, actuals: List[Type], arg_name: str, diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index bf82c0385c3a..12add8efcb3a 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -59,7 +59,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' from mypy.types import ( Type, TypeVisitor, UnboundType, AnyType, NoneType, UninhabitedType, ErasedType, DeletedType, Instance, TypeVarType, CallableType, TupleType, TypedDictType, - UnionType, Overloaded, PartialType, TypeType, LiteralType, TypeAliasType + UnionType, Overloaded, PartialType, TypeType, LiteralType, TypeAliasType, ParamSpecType ) from mypy.util import get_prefix @@ -310,6 +310,13 @@ def visit_type_var(self, typ: TypeVarType) -> SnapshotItem: snapshot_type(typ.upper_bound), typ.variance) + def visit_param_spec(self, typ: ParamSpecType) -> SnapshotItem: + return ('ParamSpec', + typ.id.raw_id, + typ.id.meta_level, + typ.flavor, + snapshot_type(typ.upper_bound)) + def visit_callable_type(self, typ: CallableType) -> SnapshotItem: # FIX generics return ('CallableType', diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index 11ce494fba0e..8db2b302b844 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -59,7 +59,7 @@ Type, SyntheticTypeVisitor, Instance, AnyType, NoneType, CallableType, ErasedType, DeletedType, TupleType, TypeType, TypedDictType, UnboundType, UninhabitedType, UnionType, Overloaded, TypeVarType, TypeList, CallableArgument, EllipsisType, StarType, LiteralType, - RawExpressionType, PartialType, PlaceholderType, TypeAliasType + RawExpressionType, PartialType, PlaceholderType, TypeAliasType, ParamSpecType ) from mypy.util import get_prefix, replace_object_state from mypy.typestate import TypeState @@ -173,7 +173,8 @@ def visit_class_def(self, node: ClassDef) -> None: node.defs.body = self.replace_statements(node.defs.body) info = node.info for tv in node.type_vars: - self.process_type_var_def(tv) + if isinstance(tv, TypeVarType): + self.process_type_var_def(tv) if info: if info.is_named_tuple: self.process_synthetic_type_info(info) @@ -407,6 +408,9 @@ def visit_type_var(self, typ: TypeVarType) -> None: for value in typ.values: value.accept(self) + def visit_param_spec(self, typ: ParamSpecType) -> None: + pass + def visit_typeddict_type(self, typ: TypedDictType) -> None: for value_type in typ.items.values(): value_type.accept(self) diff --git a/mypy/server/deps.py b/mypy/server/deps.py index f80673fdb7d4..acec79c103bf 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -99,7 +99,7 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a Type, Instance, AnyType, NoneType, TypeVisitor, CallableType, DeletedType, PartialType, TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType, FunctionLike, Overloaded, TypeOfAny, LiteralType, ErasedType, get_proper_type, ProperType, - TypeAliasType + TypeAliasType, ParamSpecType ) from mypy.server.trigger import make_trigger, make_wildcard_trigger from mypy.util import correct_relative_import @@ -951,6 +951,13 @@ def visit_type_var(self, typ: TypeVarType) -> List[str]: triggers.extend(self.get_type_triggers(val)) return triggers + def visit_param_spec(self, typ: ParamSpecType) -> List[str]: + triggers = [] + if typ.fullname: + triggers.append(make_trigger(typ.fullname)) + triggers.extend(self.get_type_triggers(typ.upper_bound)) + return triggers + def visit_typeddict_type(self, typ: TypedDictType) -> List[str]: triggers = [] for item in typ.items.values(): diff --git a/mypy/subtypes.py b/mypy/subtypes.py index a3a678a941e7..ef117925a817 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -7,7 +7,7 @@ Type, AnyType, UnboundType, TypeVisitor, FormalArgument, NoneType, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance, - FunctionLike, TypeOfAny, LiteralType, get_proper_type, TypeAliasType + FunctionLike, TypeOfAny, LiteralType, get_proper_type, TypeAliasType, ParamSpecType ) import mypy.applytype import mypy.constraints @@ -272,9 +272,14 @@ def visit_instance(self, left: Instance) -> bool: and not self.ignore_declared_variance): # Map left type to corresponding right instances. t = map_instance_to_supertype(left, right.type) - nominal = all(self.check_type_parameter(lefta, righta, tvar.variance) - for lefta, righta, tvar in - zip(t.args, right.args, right.type.defn.type_vars)) + nominal = True + for lefta, righta, tvar in zip(t.args, right.args, right.type.defn.type_vars): + if isinstance(tvar, TypeVarType): + if not self.check_type_parameter(lefta, righta, tvar.variance): + nominal = False + else: + if not is_equivalent(lefta, righta): + nominal = False if nominal: TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) return nominal @@ -312,6 +317,16 @@ def visit_type_var(self, left: TypeVarType) -> bool: return True return self._is_subtype(left.upper_bound, self.right) + def visit_param_spec(self, left: ParamSpecType) -> bool: + right = self.right + if ( + isinstance(right, ParamSpecType) + and right.id == left.id + and right.flavor == left.flavor + ): + return True + return self._is_subtype(left.upper_bound, self.right) + def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): @@ -1306,8 +1321,13 @@ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: assert isinstance(erased, Instance) left = erased - nominal = all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in - zip(left.args, right.args, right.type.defn.type_vars)) + nominal = True + for ta, ra, tvar in zip(left.args, right.args, right.type.defn.type_vars): + if isinstance(tvar, TypeVarType): + nominal = nominal and check_argument(ta, ra, tvar.variance) + else: + nominal = nominal and mypy.sametypes.is_same_type(ta, ra) + if nominal: TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) return nominal @@ -1330,6 +1350,16 @@ def visit_type_var(self, left: TypeVarType) -> bool: return True return self._is_proper_subtype(left.upper_bound, self.right) + def visit_param_spec(self, left: ParamSpecType) -> bool: + right = self.right + if ( + isinstance(right, ParamSpecType) + and right.id == left.id + and right.flavor == left.flavor + ): + return True + return self._is_proper_subtype(left.upper_bound, self.right) + def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): diff --git a/mypy/test/typefixture.py b/mypy/test/typefixture.py index 3b12caebae4b..1a5dd8164136 100644 --- a/mypy/test/typefixture.py +++ b/mypy/test/typefixture.py @@ -8,7 +8,8 @@ from mypy.semanal_shared import set_callable_name from mypy.types import ( Type, AnyType, NoneType, Instance, CallableType, TypeVarType, TypeType, - UninhabitedType, TypeOfAny, TypeAliasType, UnionType, LiteralType + UninhabitedType, TypeOfAny, TypeAliasType, UnionType, LiteralType, + TypeVarLikeType ) from mypy.nodes import ( TypeInfo, ClassDef, FuncDef, Block, ARG_POS, ARG_OPT, ARG_STAR, SymbolTable, @@ -234,7 +235,7 @@ def make_type_info(self, name: str, module_name = '__main__' if typevars: - v: List[TypeVarType] = [] + v: List[TypeVarLikeType] = [] for id, n in enumerate(typevars, 1): if variances: variance = variances[id - 1] diff --git a/mypy/tvar_scope.py b/mypy/tvar_scope.py index 9b5c1ebfb654..0d8be7845e52 100644 --- a/mypy/tvar_scope.py +++ b/mypy/tvar_scope.py @@ -1,5 +1,5 @@ from typing import Optional, Dict, Union -from mypy.types import TypeVarLikeType, TypeVarType, ParamSpecType +from mypy.types import TypeVarLikeType, TypeVarType, ParamSpecType, ParamSpecFlavor from mypy.nodes import ParamSpecExpr, TypeVarExpr, TypeVarLikeExpr, SymbolTableNode @@ -78,6 +78,8 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType: name, tvar_expr.fullname, i, + flavor=ParamSpecFlavor.BARE, + upper_bound=tvar_expr.upper_bound, line=tvar_expr.line, column=tvar_expr.column ) diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 2b4ebffb93e0..99821fa62640 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -23,7 +23,7 @@ RawExpressionType, Instance, NoneType, TypeType, UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarLikeType, UnboundType, ErasedType, StarType, EllipsisType, TypeList, CallableArgument, - PlaceholderType, TypeAliasType, get_proper_type + PlaceholderType, TypeAliasType, ParamSpecType, get_proper_type ) @@ -63,6 +63,10 @@ def visit_deleted_type(self, t: DeletedType) -> T: def visit_type_var(self, t: TypeVarType) -> T: pass + @abstractmethod + def visit_param_spec(self, t: ParamSpecType) -> T: + pass + @abstractmethod def visit_instance(self, t: Instance) -> T: pass @@ -179,6 +183,9 @@ def visit_instance(self, t: Instance) -> Type: def visit_type_var(self, t: TypeVarType) -> Type: return t + def visit_param_spec(self, t: ParamSpecType) -> Type: + return t + def visit_partial_type(self, t: PartialType) -> Type: return t @@ -291,6 +298,9 @@ def visit_deleted_type(self, t: DeletedType) -> T: def visit_type_var(self, t: TypeVarType) -> T: return self.query_types([t.upper_bound] + t.values) + def visit_param_spec(self, t: ParamSpecType) -> T: + return self.strategy([]) + def visit_partial_type(self, t: PartialType) -> T: return self.strategy([]) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index d400c7e1ca69..7cbf5d869a6e 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -17,7 +17,7 @@ StarType, PartialType, EllipsisType, UninhabitedType, TypeType, CallableArgument, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType, PlaceholderType, Overloaded, get_proper_type, TypeAliasType, - TypeVarLikeType, ParamSpecType + TypeVarLikeType, ParamSpecType, ParamSpecFlavor, callable_with_ellipsis ) from mypy.nodes import ( @@ -209,13 +209,14 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) if tvar_def is None: self.fail('ParamSpec "{}" is unbound'.format(t.name), t) return AnyType(TypeOfAny.from_error) - self.fail('Invalid location for ParamSpec "{}"'.format(t.name), t) - self.note( - 'You can use ParamSpec as the first argument to Callable, e.g., ' - "'Callable[{}, int]'".format(t.name), - t + assert isinstance(tvar_def, ParamSpecType) + if len(t.args) > 0: + self.fail('ParamSpec "{}" used with arguments'.format(t.name), t) + # Change the line number + return ParamSpecType( + tvar_def.name, tvar_def.fullname, tvar_def.id, tvar_def.flavor, + tvar_def.upper_bound, line=t.line, column=t.column, ) - return AnyType(TypeOfAny.from_error) if isinstance(sym.node, TypeVarExpr) and tvar_def is not None and self.defining_alias: self.fail('Can\'t use bound type variable "{}"' ' to define generic alias'.format(t.name), t) @@ -381,7 +382,8 @@ def analyze_type_with_type_info( # checked only later, since we do not always know the # valid count at this point. Thus we may construct an # Instance with an invalid number of type arguments. - instance = Instance(info, self.anal_array(args), ctx.line, ctx.column) + instance = Instance(info, self.anal_array(args, allow_param_spec=True), + ctx.line, ctx.column) # Check type argument count. if len(instance.args) != len(info.type_vars) and not self.defining_alias: fix_instance(instance, self.fail, self.note, @@ -531,6 +533,9 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type: def visit_type_var(self, t: TypeVarType) -> Type: return t + def visit_param_spec(self, t: ParamSpecType) -> Type: + return t + def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: # Every Callable can bind its own type variables, if they're not in the outer scope with self.tvar_scope_frame(): @@ -539,7 +544,15 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: else: variables = self.bind_function_type_variables(t, t) special = self.anal_type_guard(t.ret_type) - ret = t.copy_modified(arg_types=self.anal_array(t.arg_types, nested=nested), + arg_kinds = t.arg_kinds + if len(arg_kinds) >= 2 and arg_kinds[-2] == ARG_STAR and arg_kinds[-1] == ARG_STAR2: + arg_types = self.anal_array(t.arg_types[:-2], nested=nested) + [ + self.anal_star_arg_type(t.arg_types[-2], ARG_STAR, nested=nested), + self.anal_star_arg_type(t.arg_types[-1], ARG_STAR2, nested=nested), + ] + else: + arg_types = self.anal_array(t.arg_types, nested=nested) + ret = t.copy_modified(arg_types=arg_types, ret_type=self.anal_type(t.ret_type, nested=nested), # If the fallback isn't filled in yet, # its type will be the falsey FakeInfo @@ -566,6 +579,26 @@ def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Optional[Type]: return self.anal_type(t.args[0]) return None + def anal_star_arg_type(self, t: Type, kind: ArgKind, nested: bool) -> Type: + """Analyze signature argument type for *args and **kwargs argument.""" + # TODO: Check that suffix and kind match + if isinstance(t, UnboundType) and t.name and '.' in t.name and not t.args: + components = t.name.split('.') + sym = self.lookup_qualified('.'.join(components[:-1]), t) + if sym is not None and isinstance(sym.node, ParamSpecExpr): + tvar_def = self.tvar_scope.get_binding(sym) + if isinstance(tvar_def, ParamSpecType): + if kind == ARG_STAR: + flavor = ParamSpecFlavor.ARGS + elif kind == ARG_STAR2: + flavor = ParamSpecFlavor.KWARGS + else: + assert False, kind + return ParamSpecType(tvar_def.name, tvar_def.fullname, tvar_def.id, flavor, + upper_bound=self.named_type('builtins.object'), + line=t.line, column=t.column) + return self.anal_type(t, nested=nested) + def visit_overloaded(self, t: Overloaded) -> Type: # Overloaded types are manually constructed in semanal.py by analyzing the # AST and combining together the Callable types this visitor converts. @@ -694,14 +727,17 @@ def analyze_callable_args_for_paramspec( if not isinstance(tvar_def, ParamSpecType): return None - # TODO(PEP612): construct correct type for paramspec + # TODO: Use tuple[...] or Mapping[..] instead? + obj = self.named_type('builtins.object') return CallableType( - [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)], + [ParamSpecType(tvar_def.name, tvar_def.fullname, tvar_def.id, ParamSpecFlavor.ARGS, + upper_bound=obj), + ParamSpecType(tvar_def.name, tvar_def.fullname, tvar_def.id, ParamSpecFlavor.KWARGS, + upper_bound=obj)], [nodes.ARG_STAR, nodes.ARG_STAR2], [None, None], ret_type=ret_type, fallback=fallback, - is_ellipsis_args=True ) def analyze_callable_type(self, t: UnboundType) -> Type: @@ -709,12 +745,7 @@ def analyze_callable_type(self, t: UnboundType) -> Type: if len(t.args) == 0: # Callable (bare). Treat as Callable[..., Any]. any_type = self.get_omitted_any(t) - ret = CallableType([any_type, any_type], - [nodes.ARG_STAR, nodes.ARG_STAR2], - [None, None], - ret_type=any_type, - fallback=fallback, - is_ellipsis_args=True) + ret = callable_with_ellipsis(any_type, any_type, fallback) elif len(t.args) == 2: callable_args = t.args[0] ret_type = t.args[1] @@ -731,13 +762,9 @@ def analyze_callable_type(self, t: UnboundType) -> Type: fallback=fallback) elif isinstance(callable_args, EllipsisType): # Callable[..., RET] (with literal ellipsis; accept arbitrary arguments) - ret = CallableType([AnyType(TypeOfAny.explicit), - AnyType(TypeOfAny.explicit)], - [nodes.ARG_STAR, nodes.ARG_STAR2], - [None, None], - ret_type=ret_type, - fallback=fallback, - is_ellipsis_args=True) + ret = callable_with_ellipsis(AnyType(TypeOfAny.explicit), + ret_type=ret_type, + fallback=fallback) else: # Callable[P, RET] (where P is ParamSpec) maybe_ret = self.analyze_callable_args_for_paramspec( @@ -956,20 +983,33 @@ def is_defined_type_var(self, tvar: str, context: Context) -> bool: return False return self.tvar_scope.get_binding(tvar_node) is not None - def anal_array(self, a: Iterable[Type], nested: bool = True) -> List[Type]: + def anal_array(self, + a: Iterable[Type], + nested: bool = True, *, + allow_param_spec: bool = False) -> List[Type]: res: List[Type] = [] for t in a: - res.append(self.anal_type(t, nested)) + res.append(self.anal_type(t, nested, allow_param_spec=allow_param_spec)) return res - def anal_type(self, t: Type, nested: bool = True) -> Type: + def anal_type(self, t: Type, nested: bool = True, *, allow_param_spec: bool = False) -> Type: if nested: self.nesting_level += 1 try: - return t.accept(self) + analyzed = t.accept(self) finally: if nested: self.nesting_level -= 1 + if (not allow_param_spec + and isinstance(analyzed, ParamSpecType) + and analyzed.flavor == ParamSpecFlavor.BARE): + self.fail('Invalid location for ParamSpec "{}"'.format(analyzed.name), t) + self.note( + 'You can use ParamSpec as the first argument to Callable, e.g., ' + "'Callable[{}, int]'".format(analyzed.name), + t + ) + return analyzed def anal_var_def(self, var_def: TypeVarLikeType) -> TypeVarLikeType: if isinstance(var_def, TypeVarType): @@ -1187,6 +1227,7 @@ def flatten_tvars(ll: Iterable[List[T]]) -> List[T]: class TypeVarLikeQuery(TypeQuery[TypeVarLikeList]): + """Find TypeVar and ParamSpec references in an unbound type.""" def __init__(self, lookup: Callable[[str, Context], Optional[SymbolTableNode]], @@ -1209,7 +1250,17 @@ def _seems_like_callable(self, type: UnboundType) -> bool: def visit_unbound_type(self, t: UnboundType) -> TypeVarLikeList: name = t.name - node = self.lookup(name, t) + node = None + # Special case P.args and P.kwargs for ParamSpecs only. + if name.endswith('args'): + if name.endswith('.args') or name.endswith('.kwargs'): + base = '.'.join(name.split('.')[:-1]) + n = self.lookup(base, t) + if n is not None and isinstance(n.node, ParamSpecExpr): + node = n + name = base + if node is None: + node = self.lookup(name, t) if node and isinstance(node.node, TypeVarLikeExpr) and ( self.include_bound_tvars or self.scope.get_binding(node) is None): assert isinstance(node.node, TypeVarLikeExpr) diff --git a/mypy/types.py b/mypy/types.py index c709a96ab204..81b6de37acd6 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -3,7 +3,6 @@ import copy import sys from abc import abstractmethod -from mypy.backports import OrderedDict from typing import ( Any, TypeVar, Dict, List, Tuple, cast, Set, Optional, Union, Iterable, NamedTuple, @@ -11,6 +10,7 @@ ) from typing_extensions import ClassVar, Final, TYPE_CHECKING, overload, TypeAlias as _TypeAlias +from mypy.backports import OrderedDict import mypy.nodes from mypy import state from mypy.nodes import ( @@ -363,14 +363,16 @@ def is_meta_var(self) -> bool: class TypeVarLikeType(ProperType): - __slots__ = ('name', 'fullname', 'id') + __slots__ = ('name', 'fullname', 'id', 'upper_bound') name: str # Name (may be qualified) fullname: str # Fully qualified name id: TypeVarId + upper_bound: Type def __init__( - self, name: str, fullname: str, id: Union[TypeVarId, int], line: int = -1, column: int = -1 + self, name: str, fullname: str, id: Union[TypeVarId, int], upper_bound: Type, + line: int = -1, column: int = -1 ) -> None: super().__init__(line, column) self.name = name @@ -378,6 +380,7 @@ def __init__( if isinstance(id, int): id = TypeVarId(id) self.id = id + self.upper_bound = upper_bound def serialize(self) -> JsonDict: raise NotImplementedError @@ -388,21 +391,19 @@ def deserialize(cls, data: JsonDict) -> 'TypeVarLikeType': class TypeVarType(TypeVarLikeType): - """Definition of a single type variable.""" + """Type that refers to a type variable.""" - __slots__ = ('values', 'upper_bound', 'variance') + __slots__ = ('values', 'variance') values: List[Type] # Value restriction, empty list if no restriction - upper_bound: Type variance: int def __init__(self, name: str, fullname: str, id: Union[TypeVarId, int], values: List[Type], upper_bound: Type, variance: int = INVARIANT, line: int = -1, column: int = -1) -> None: - super().__init__(name, fullname, id, line, column) + super().__init__(name, fullname, id, upper_bound, line, column) assert values is not None, "No restrictions must be represented by empty list" self.values = values - self.upper_bound = upper_bound self.variance = variance @staticmethod @@ -446,13 +447,69 @@ def deserialize(cls, data: JsonDict) -> 'TypeVarType': ) +class ParamSpecFlavor: + # Simple ParamSpec reference such as "P" + BARE: Final = 0 + # P.args + ARGS: Final = 1 + # P.kwargs + KWARGS: Final = 2 + + class ParamSpecType(TypeVarLikeType): - """Definition of a single ParamSpec variable.""" + """Type that refers to a ParamSpec. - __slots__ = () + A ParamSpec is a type variable that represents the parameter + types, names and kinds of a callable (i.e., the signature without + the return type). - def __repr__(self) -> str: - return self.name + This can be one of these forms + * P (ParamSpecFlavor.BARE) + * P.args (ParamSpecFlavor.ARGS) + * P.kwargs (ParamSpecFLavor.KWARGS) + + The upper_bound is really used as a fallback type -- it's shared + with TypeVarType for simplicity. It can't be specified by the user + and the value is directly derived from the flavor (currently + always just 'object'). + """ + + __slots__ = ('flavor',) + + flavor: int + + def __init__( + self, name: str, fullname: str, id: Union[TypeVarId, int], flavor: int, + upper_bound: Type, *, line: int = -1, column: int = -1 + ) -> None: + super().__init__(name, fullname, id, upper_bound, line=line, column=column) + self.flavor = flavor + + @staticmethod + def new_unification_variable(old: 'ParamSpecType') -> 'ParamSpecType': + new_id = TypeVarId.new(meta_level=1) + return ParamSpecType(old.name, old.fullname, new_id, old.flavor, old.upper_bound, + line=old.line, column=old.column) + + def accept(self, visitor: 'TypeVisitor[T]') -> T: + return visitor.visit_param_spec(self) + + def name_with_suffix(self) -> str: + n = self.name + if self.flavor == ParamSpecFlavor.ARGS: + return f'{n}.args' + elif self.flavor == ParamSpecFlavor.KWARGS: + return f'{n}.kwargs' + return n + + def __hash__(self) -> int: + return hash((self.id, self.flavor)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ParamSpecType): + return NotImplemented + # Upper bound can be ignored, since it's determined by flavor. + return self.id == other.id and self.flavor == other.flavor def serialize(self) -> JsonDict: assert not self.id.is_meta_var() @@ -461,6 +518,8 @@ def serialize(self) -> JsonDict: 'name': self.name, 'fullname': self.fullname, 'id': self.id.raw_id, + 'flavor': self.flavor, + 'upper_bound': self.upper_bound.serialize(), } @classmethod @@ -470,6 +529,8 @@ def deserialize(cls, data: JsonDict) -> 'ParamSpecType': data['name'], data['fullname'], data['id'], + data['flavor'], + deserialize_type(data['upper_bound']), ) @@ -1254,6 +1315,27 @@ def type_var_ids(self) -> List[TypeVarId]: a.append(tv.id) return a + def param_spec(self) -> Optional[ParamSpecType]: + """Return ParamSpec if callable can be called with one. + + A Callable accepting ParamSpec P args (*args, **kwargs) must have the + two final parameters like this: *args: P.args, **kwargs: P.kwargs. + """ + if len(self.arg_types) < 2: + return None + if self.arg_kinds[-2] != ARG_STAR or self.arg_kinds[-1] != ARG_STAR2: + return None + arg_type = self.arg_types[-2] + if not isinstance(arg_type, ParamSpecType): + return None + return ParamSpecType(arg_type.name, arg_type.fullname, arg_type.id, ParamSpecFlavor.BARE, + arg_type.upper_bound) + + def expand_param_spec(self, c: 'CallableType') -> 'CallableType': + return self.copy_modified(arg_types=self.arg_types[:-2] + c.arg_types, + arg_kinds=self.arg_kinds[:-2] + c.arg_kinds, + arg_names=self.arg_names[:-2] + c.arg_names) + def __hash__(self) -> int: return hash((self.ret_type, self.is_type_obj(), self.is_ellipsis_args, self.name, @@ -2118,10 +2200,25 @@ def visit_type_var(self, t: TypeVarType) -> str: s += '(upper_bound={})'.format(t.upper_bound.accept(self)) return s + def visit_param_spec(self, t: ParamSpecType) -> str: + if t.name is None: + # Anonymous type variable type (only numeric id). + s = f'`{t.id}' + else: + # Named type variable type. + s = f'{t.name_with_suffix()}`{t.id}' + return s + def visit_callable_type(self, t: CallableType) -> str: + param_spec = t.param_spec() + if param_spec is not None: + num_skip = 2 + else: + num_skip = 0 + s = '' bare_asterisk = False - for i in range(len(t.arg_types)): + for i in range(len(t.arg_types) - num_skip): if s != '': s += ', ' if t.arg_kinds[i].is_named() and not bare_asterisk: @@ -2138,6 +2235,12 @@ def visit_callable_type(self, t: CallableType) -> str: if t.arg_kinds[i].is_optional(): s += ' =' + if param_spec is not None: + n = param_spec.name + if s: + s += ', ' + s += f'*{n}.args, **{n}.kwargs' + s = '({})'.format(s) if not isinstance(get_proper_type(t.ret_type), NoneType): @@ -2159,8 +2262,8 @@ def visit_callable_type(self, t: CallableType) -> str: else: vs.append(var.name) else: - # For other TypeVarLikeTypes, just use the repr - vs.append(repr(var)) + # For other TypeVarLikeTypes, just use the name + vs.append(var.name) s = '{} {}'.format('[{}]'.format(', '.join(vs)), s) return 'def {}'.format(s) @@ -2419,3 +2522,15 @@ def is_literal_type(typ: ProperType, fallback_fullname: str, value: LiteralValue for key, obj in names.items() if isinstance(obj, type) and issubclass(obj, Type) and obj is not Type } + + +def callable_with_ellipsis(any_type: AnyType, + ret_type: Type, + fallback: Instance) -> CallableType: + """Construct type Callable[..., ret_type].""" + return CallableType([any_type, any_type], + [ARG_STAR, ARG_STAR2], + [None, None], + ret_type=ret_type, + fallback=fallback, + is_ellipsis_args=True) diff --git a/mypy/typetraverser.py b/mypy/typetraverser.py index 3bebd3831971..a03784b0406e 100644 --- a/mypy/typetraverser.py +++ b/mypy/typetraverser.py @@ -6,7 +6,7 @@ Type, SyntheticTypeVisitor, AnyType, UninhabitedType, NoneType, ErasedType, DeletedType, TypeVarType, LiteralType, Instance, CallableType, TupleType, TypedDictType, UnionType, Overloaded, TypeType, CallableArgument, UnboundType, TypeList, StarType, EllipsisType, - PlaceholderType, PartialType, RawExpressionType, TypeAliasType + PlaceholderType, PartialType, RawExpressionType, TypeAliasType, ParamSpecType ) @@ -37,6 +37,9 @@ def visit_type_var(self, t: TypeVarType) -> None: # definition. We want to traverse everything just once. pass + def visit_param_spec(self, t: ParamSpecType) -> None: + pass + def visit_literal_type(self, t: LiteralType) -> None: t.fallback.accept(self) diff --git a/mypy/typevars.py b/mypy/typevars.py index f595551bd3bc..b49194f342e0 100644 --- a/mypy/typevars.py +++ b/mypy/typevars.py @@ -3,7 +3,7 @@ from mypy.nodes import TypeInfo from mypy.erasetype import erase_typevars -from mypy.types import Instance, TypeVarType, TupleType, Type, TypeOfAny, AnyType +from mypy.types import Instance, TypeVarType, TupleType, Type, TypeOfAny, AnyType, ParamSpecType def fill_typevars(typ: TypeInfo) -> Union[Instance, TupleType]: @@ -16,10 +16,15 @@ def fill_typevars(typ: TypeInfo) -> Union[Instance, TupleType]: for i in range(len(typ.defn.type_vars)): tv = typ.defn.type_vars[i] # Change the line number - tv = TypeVarType( - tv.name, tv.fullname, tv.id, tv.values, - tv.upper_bound, tv.variance, line=-1, column=-1, - ) + if isinstance(tv, TypeVarType): + tv = TypeVarType( + tv.name, tv.fullname, tv.id, tv.values, + tv.upper_bound, tv.variance, line=-1, column=-1, + ) + else: + assert isinstance(tv, ParamSpecType) + tv = ParamSpecType(tv.name, tv.fullname, tv.id, tv.flavor, tv.upper_bound, + line=-1, column=-1) tvs.append(tv) inst = Instance(typ, tvs) if typ.tuple_type is None: diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 4f1c917c32ff..c7ed0716a118 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -28,51 +28,303 @@ def foo6(x: Callable[[P], int]) -> None: ... # E: Invalid location for ParamSpe # N: You can use ParamSpec as the first argument to Callable, e.g., 'Callable[P, int]' [builtins fixtures/tuple.pyi] -[case testParamSpecTemporaryAnyBehaviour] -# TODO(PEP612): behaviour tested here should change -# This is a test of mypy's temporary behaviour in lieu of full support for ParamSpec +[case testParamSpecContextManagerLike] from typing import Callable, List, Iterator, TypeVar from typing_extensions import ParamSpec P = ParamSpec('P') T = TypeVar('T') -def changes_return_type_to_str(x: Callable[P, int]) -> Callable[P, str]: ... - -def returns_int(a: str, b: bool) -> int: ... - -reveal_type(changes_return_type_to_str(returns_int)) # N: Revealed type is "def (*Any, **Any) -> builtins.str" - def tmpcontextmanagerlike(x: Callable[P, Iterator[T]]) -> Callable[P, List[T]]: ... @tmpcontextmanagerlike def whatever(x: int) -> Iterator[int]: yield x -reveal_type(whatever) # N: Revealed type is "def (*Any, **Any) -> builtins.list[builtins.int*]" +reveal_type(whatever) # N: Revealed type is "def (x: builtins.int) -> builtins.list[builtins.int*]" reveal_type(whatever(217)) # N: Revealed type is "builtins.list[builtins.int*]" [builtins fixtures/tuple.pyi] -[case testGenericParamSpecTemporaryBehaviour] -# flags: --python-version 3.10 -# TODO(PEP612): behaviour tested here should change -from typing import Generic, TypeVar, Callable, ParamSpec - -T = TypeVar("T") -P = ParamSpec("P") - -class X(Generic[T, P]): # E: Free type variable expected in Generic[...] - f: Callable[P, int] # E: The first argument to Callable must be a list of types or "..." - x: T -[out] - [case testInvalidParamSpecType] # flags: --python-version 3.10 from typing import ParamSpec P = ParamSpec("P") -class MyFunction(P): # E: Invalid location for ParamSpec "P" \ - # N: You can use ParamSpec as the first argument to Callable, e.g., 'Callable[P, int]' +class MyFunction(P): # E: Invalid base class "P" ... -a: MyFunction[int] # E: "MyFunction" expects no type arguments, but 1 given +[case testParamSpecRevealType] +from typing import Callable +from typing_extensions import ParamSpec + +P = ParamSpec('P') + +def f(x: Callable[P, int]) -> None: ... +reveal_type(f) # N: Revealed type is "def [P] (x: def (*P.args, **P.kwargs) -> builtins.int)" +[builtins fixtures/tuple.pyi] + +[case testParamSpecSimpleFunction] +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec('P') + +def changes_return_type_to_str(x: Callable[P, int]) -> Callable[P, str]: ... + +def returns_int(a: str, b: bool) -> int: ... + +reveal_type(changes_return_type_to_str(returns_int)) # N: Revealed type is "def (a: builtins.str, b: builtins.bool) -> builtins.str" +[builtins fixtures/tuple.pyi] + +[case testParamSpecSimpleClass] +from typing import Callable, TypeVar, Generic +from typing_extensions import ParamSpec + +P = ParamSpec('P') + +class C(Generic[P]): + def __init__(self, x: Callable[P, None]) -> None: ... + + def m(self, *args: P.args, **kwargs: P.kwargs) -> int: + return 1 + +def f(x: int, y: str) -> None: ... + +reveal_type(C(f)) # N: Revealed type is "__main__.C[def (x: builtins.int, y: builtins.str)]" +reveal_type(C(f).m) # N: Revealed type is "def (x: builtins.int, y: builtins.str) -> builtins.int" +[builtins fixtures/dict.pyi] + +[case testParamSpecClassWithPrefixArgument] +from typing import Callable, TypeVar, Generic +from typing_extensions import ParamSpec + +P = ParamSpec('P') + +class C(Generic[P]): + def __init__(self, x: Callable[P, None]) -> None: ... + + def m(self, a: str, *args: P.args, **kwargs: P.kwargs) -> int: + return 1 + +def f(x: int, y: str) -> None: ... + +reveal_type(C(f).m) # N: Revealed type is "def (a: builtins.str, x: builtins.int, y: builtins.str) -> builtins.int" +reveal_type(C(f).m('', 1, '')) # N: Revealed type is "builtins.int" +[builtins fixtures/dict.pyi] + +[case testParamSpecDecorator] +from typing import Callable, TypeVar, Generic +from typing_extensions import ParamSpec + +P = ParamSpec('P') +R = TypeVar('R') + +class W(Generic[P, R]): + f: Callable[P, R] + x: int + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + reveal_type(self.f(*args, **kwargs)) # N: Revealed type is "R`2" + return self.f(*args, **kwargs) + +def dec() -> Callable[[Callable[P, R]], W[P, R]]: + pass + +@dec() +def f(a: int, b: str) -> None: ... + +reveal_type(f) # N: Revealed type is "__main__.W[def (a: builtins.int, b: builtins.str), None]" +reveal_type(f(1, '')) # N: Revealed type is "None" +reveal_type(f.x) # N: Revealed type is "builtins.int" + +## TODO: How should this work? +# +# class C: +# @dec() +# def m(self, x: int) -> str: ... +# +# reveal_type(C().m(x=1)) +[builtins fixtures/dict.pyi] + +[case testParamSpecFunction] +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec('P') +R = TypeVar('R') + +def f(x: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + return x(*args, **kwargs) + +def g(x: int, y: str) -> None: ... + +reveal_type(f(g, 1, y='x')) # N: Revealed type is "None" +f(g, 'x', y='x') # E: Argument 2 to "f" has incompatible type "str"; expected "int" +f(g, 1, y=1) # E: Argument "y" to "f" has incompatible type "int"; expected "str" +f(g) # E: Missing positional arguments "x", "y" in call to "f" + +[builtins fixtures/dict.pyi] + +[case testParamSpecSpecialCase] +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec('P') +T = TypeVar('T') + +def register(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Callable[P, T]: ... + +def f(x: int, y: str, z: int, a: str) -> None: ... + +x = register(f, 1, '', 1, '') +[builtins fixtures/dict.pyi] + +[case testParamSpecInferredFromAny] +from typing import Callable, Any +from typing_extensions import ParamSpec + +P = ParamSpec('P') + +def f(x: Callable[P, int]) -> Callable[P, str]: ... + +g: Any +reveal_type(f(g)) # N: Revealed type is "def (*Any, **Any) -> builtins.str" + +f(g)(1, 3, x=1, y=2) +[builtins fixtures/tuple.pyi] + +[case testParamSpecDecoratorImplementation] +from typing import Callable, Any, TypeVar, List +from typing_extensions import ParamSpec + +P = ParamSpec('P') +T = TypeVar('T') + +def dec(f: Callable[P, T]) -> Callable[P, List[T]]: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> List[T]: + return [f(*args, **kwargs)] + return wrapper + +@dec +def g(x: int, y: str = '') -> int: ... + +reveal_type(g) # N: Revealed type is "def (x: builtins.int, y: builtins.str =) -> builtins.list[builtins.int*]" +[builtins fixtures/dict.pyi] + +[case testParamSpecArgsAndKwargsTypes] +from typing import Callable, TypeVar, Generic +from typing_extensions import ParamSpec + +P = ParamSpec('P') + +class C(Generic[P]): + def __init__(self, x: Callable[P, None]) -> None: ... + + def m(self, *args: P.args, **kwargs: P.kwargs) -> None: + reveal_type(args) # N: Revealed type is "P.args`1" + reveal_type(kwargs) # N: Revealed type is "P.kwargs`1" +[builtins fixtures/dict.pyi] + +[case testParamSpecSubtypeChecking1] +from typing import Callable, TypeVar, Generic, Any +from typing_extensions import ParamSpec + +P = ParamSpec('P') + +class C(Generic[P]): + def __init__(self, x: Callable[P, None]) -> None: ... + + def m(self, *args: P.args, **kwargs: P.kwargs) -> None: + args = args + kwargs = kwargs + o: object + o = args + o = kwargs + o2: object + args = o2 # E: Incompatible types in assignment (expression has type "object", variable has type "P.args") + kwargs = o2 # E: Incompatible types in assignment (expression has type "object", variable has type "P.kwargs") + a: Any + a = args + a = kwargs + args = kwargs # E: Incompatible types in assignment (expression has type "P.kwargs", variable has type "P.args") + kwargs = args # E: Incompatible types in assignment (expression has type "P.args", variable has type "P.kwargs") + args = a + kwargs = a +[builtins fixtures/dict.pyi] + +[case testParamSpecSubtypeChecking2] +from typing import Callable, Generic +from typing_extensions import ParamSpec + +P = ParamSpec('P') +P2 = ParamSpec('P2') + +class C(Generic[P]): + pass + +def f(c1: C[P], c2: C[P2]) -> None: + c1 = c1 + c2 = c2 + c1 = c2 # E: Incompatible types in assignment (expression has type "C[P2]", variable has type "C[P]") + c2 = c1 # E: Incompatible types in assignment (expression has type "C[P]", variable has type "C[P2]") + +def g(f: Callable[P, None], g: Callable[P2, None]) -> None: + f = f + g = g + f = g # E: Incompatible types in assignment (expression has type "Callable[P2, None]", variable has type "Callable[P, None]") + g = f # E: Incompatible types in assignment (expression has type "Callable[P, None]", variable has type "Callable[P2, None]") +[builtins fixtures/dict.pyi] + +[case testParamSpecJoin] +from typing import Callable, Generic, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec('P') +P2 = ParamSpec('P2') +P3 = ParamSpec('P3') +T = TypeVar('T') + +def join(x: T, y: T) -> T: ... + +class C(Generic[P, P2]): + def m(self, f: Callable[P, None], g: Callable[P2, None]) -> None: + reveal_type(join(f, f)) # N: Revealed type is "def (*P.args, **P.kwargs)" + reveal_type(join(f, g)) # N: Revealed type is "builtins.function*" + + def m2(self, *args: P.args, **kwargs: P.kwargs) -> None: + reveal_type(join(args, args)) # N: Revealed type is "P.args`1" + reveal_type(join(kwargs, kwargs)) # N: Revealed type is "P.kwargs`1" + reveal_type(join(args, kwargs)) # N: Revealed type is "builtins.object*" + def f(*args2: P2.args, **kwargs2: P2.kwargs) -> None: + reveal_type(join(args, args2)) # N: Revealed type is "builtins.object*" + reveal_type(join(kwargs, kwargs2)) # N: Revealed type is "builtins.object*" + + def m3(self, c: C[P, P3]) -> None: + reveal_type(join(c, c)) # N: Revealed type is "__main__.C*[P`1, P3`-1]" + reveal_type(join(self, c)) # N: Revealed type is "builtins.object*" +[builtins fixtures/dict.pyi] + +[case testParamSpecClassWithAny] +from typing import Callable, Generic, Any +from typing_extensions import ParamSpec + +P = ParamSpec('P') + +class C(Generic[P]): + def __init__(self, x: Callable[P, None]) -> None: ... + + def m(self, *args: P.args, **kwargs: P.kwargs) -> int: + return 1 + +c: C[Any] +reveal_type(c) # N: Revealed type is "__main__.C[Any]" +reveal_type(c.m) # N: Revealed type is "def (*args: Any, **kwargs: Any) -> builtins.int" +c.m(4, 6, y='x') +c = c + +def f() -> None: pass + +c2 = C(f) +c2 = c +c3 = C(f) +c = c3 +[builtins fixtures/dict.pyi]