diff --git a/extensions/mypy_extensions.py b/extensions/mypy_extensions.py index c5442a3e1c00..f4bad5602af3 100644 --- a/extensions/mypy_extensions.py +++ b/extensions/mypy_extensions.py @@ -157,3 +157,10 @@ def __getitem__(self, args): FlexibleAlias = _FlexibleAliasCls() + + +def delegate(base_func, exclude=()): + def decorator(func): + return func + + return decorator diff --git a/mypy/checker.py b/mypy/checker.py index ed3955bc6a3a..99bdaf514ed4 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -13,7 +13,7 @@ SymbolTable, Statement, MypyFile, Var, Expression, Lvalue, Node, OverloadedFuncDef, FuncDef, FuncItem, FuncBase, TypeInfo, ClassDef, Block, AssignmentStmt, NameExpr, MemberExpr, IndexExpr, - TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, IfStmt, + TupleExpr, ListExpr, SetExpr, ExpressionStmt, ReturnStmt, IfStmt, WhileStmt, OperatorAssignmentStmt, WithStmt, AssertStmt, RaiseStmt, TryStmt, ForStmt, DelStmt, CallExpr, IntExpr, StrExpr, UnicodeExpr, OpExpr, UnaryExpr, LambdaExpr, TempNode, SymbolTableNode, @@ -2733,6 +2733,27 @@ def visit_decorator(self, e: Decorator) -> None: self.fail('Single overload definition, multiple required', e) continue dec = self.expr_checker.accept(d) + + if (isinstance(d, CallExpr) + and getattr(d.callee, 'fullname', '') == 'mypy_extensions.delegate' + and d.args + and isinstance(sig, CallableType) # TODO allow Overloaded? + and sig.is_kw_arg): + # TODO how should this combine with other decorators? + delegate_sig = self.expr_checker.accept(d.args[0]) + if not isinstance(delegate_sig, CallableType): + continue # TODO error message? + exclude = [] # type: List[str] + if d.arg_names[1:2] == ['exclude']: + exclude_arg = d.args[1] + if not (isinstance(exclude_arg, (ListExpr, TupleExpr, SetExpr)) + and all(isinstance(ex, StrExpr) + for ex in exclude_arg.items)): + continue # TODO error message? + exclude = [s.value for s in cast(List[StrExpr], exclude_arg.items)] + sig = self._delegated_sig(delegate_sig, sig, exclude) + continue + temp = self.temp_node(sig) fullname = None if isinstance(d, RefExpr): @@ -2751,6 +2772,30 @@ def visit_decorator(self, e: Decorator) -> None: if e.func.info and not e.func.is_dynamic(): self.check_method_override(e) + def _delegated_sig(self, + delegate_sig: CallableType, + sig: CallableType, + exclude: List[str]) -> CallableType: + # TODO: also delegate *args (currently only does **kwargs) + args = [(name, + kind if kind != nodes.ARG_OPT else nodes.ARG_NAMED_OPT, + typ) + for (name, kind, typ) in + zip(delegate_sig.arg_names, + delegate_sig.arg_kinds, + delegate_sig.arg_types) + if kind not in (nodes.ARG_POS, nodes.ARG_STAR) + and name not in sig.arg_names + and name not in exclude] + names, kinds, types = map(list, zip(*args)) + # **kwargs are always last in the signature, so we remove them with [:-1] + sig = sig.copy_modified( + arg_names=sig.arg_names[:-1] + cast(List[Optional[str]], names), + arg_kinds=sig.arg_kinds[:-1] + cast(List[int], kinds), + arg_types=sig.arg_types[:-1] + cast(List[Type], types), + ) + return sig + def check_for_untyped_decorator(self, func: FuncDef, dec_type: Type, diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 10202fe99eab..500876ad7c66 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -43,6 +43,7 @@ 'check-lists.test', 'check-namedtuple.test', 'check-typeddict.test', + 'check-delegate.test', 'check-type-aliases.test', 'check-ignore.test', 'check-type-promotion.test', diff --git a/test-data/unit/check-delegate.test b/test-data/unit/check-delegate.test new file mode 100644 index 000000000000..1e7f42037fde --- /dev/null +++ b/test-data/unit/check-delegate.test @@ -0,0 +1,67 @@ +-- Delegating arguments + +[case testSimpleDelegation] +from mypy_extensions import delegate + +def raw(name: str = 'he', age: int = 42): + return '%s is %s' % (name, age) + +@delegate(raw) +def cooked(**kwargs): + return raw(**kwargs) + +reveal_type(cooked) # E: Revealed type is 'def (*, name: builtins.str =, age: builtins.int =) -> Any' +cooked(x=56) # E: Unexpected keyword argument "x" for "cooked" +[builtins fixtures/dict.pyi] + + +[case testDelegationWithPositionalArg] +from mypy_extensions import delegate + +def raw(foo, name='he', age=42): + return '%s is %s' % (name, age) + +@delegate(raw) +def cooked(foo, bar, **kwargs): + return raw(foo, **kwargs) + +reveal_type(cooked) # E: Revealed type is 'def (foo: Any, bar: Any, *, name: Any =, age: Any =) -> Any' +cooked(3) # E: Too few arguments for "cooked" +cooked(3, 4) +cooked(3, 4, 5) # E: Too many positional arguments for "cooked" +cooked(3, 4, name='bob') +cooked(3, 4, x='bob') # E: Unexpected keyword argument "x" for "cooked" +[builtins fixtures/dict.pyi] + + +[case testDelegationWithKeywordOnlyArg] +from mypy_extensions import delegate + +def raw(*, name, age): + return '%s is %s' % (name, age) + +@delegate(raw) +def cooked(foo, bar, **kwargs): + return raw(foo, **kwargs) + +reveal_type(cooked) # E: Revealed type is 'def (foo: Any, bar: Any, *, name: Any, age: Any) -> Any' +cooked(3, 4, name='bob', age=34) +cooked(3, 4, name='bob') # E: Missing named argument "age" for "cooked" +cooked(3, 4, x='bob') # E: Unexpected keyword argument "x" for "cooked" +[builtins fixtures/dict.pyi] + + +[case testDelegationWithExclude] +from mypy_extensions import delegate + +def raw(name='he', age=42): + return '%s is %s' % (name, age) + +@delegate(raw, exclude=['name']) +def cooked(**kwargs): + return raw(name='bob', **kwargs) + +reveal_type(cooked) # E: Revealed type is 'def (*, age: Any =) -> Any' +cooked(age=32) +cooked(name='me') # E: Unexpected keyword argument "name" for "cooked" +[builtins fixtures/dict.pyi] diff --git a/test-data/unit/lib-stub/mypy_extensions.pyi b/test-data/unit/lib-stub/mypy_extensions.pyi index 791ff9b2d7ea..1bd77d00b0ff 100644 --- a/test-data/unit/lib-stub/mypy_extensions.pyi +++ b/test-data/unit/lib-stub/mypy_extensions.pyi @@ -1,5 +1,5 @@ # NOTE: Requires fixtures/dict.pyi -from typing import Dict, Type, TypeVar, Optional, Any, Generic +from typing import Dict, Type, TypeVar, Optional, Any, Generic, Callable, List _T = TypeVar('_T') _U = TypeVar('_U') @@ -27,3 +27,6 @@ def trait(cls: Any) -> Any: ... class NoReturn: pass class FlexibleAlias(Generic[_T, _U]): ... + + +def delegate(base_func: Callable, exclude: List[str] = ()) -> Callable: ...