-
-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Add type checking plugin support for functions #3299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,7 @@ | |
| from mypy.util import split_module_names | ||
| from mypy.typevars import fill_typevars | ||
| from mypy.visitor import ExpressionVisitor | ||
| from mypy.funcplugins import get_function_plugin_callbacks, PluginCallback | ||
|
|
||
| from mypy import experiments | ||
|
|
||
|
|
@@ -103,6 +104,7 @@ class ExpressionChecker(ExpressionVisitor[Type]): | |
| type_context = None # type: List[Optional[Type]] | ||
|
|
||
| strfrm_checker = None # type: StringFormatterChecker | ||
| function_plugins = None # type: Dict[str, PluginCallback] | ||
|
|
||
| def __init__(self, | ||
| chk: 'mypy.checker.TypeChecker', | ||
|
|
@@ -112,6 +114,7 @@ def __init__(self, | |
| self.msg = msg | ||
| self.type_context = [None] | ||
| self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg) | ||
| self.function_plugins = get_function_plugin_callbacks(self.chk.options.python_version) | ||
|
|
||
| def visit_name_expr(self, e: NameExpr) -> Type: | ||
| """Type check a name expression. | ||
|
|
@@ -198,7 +201,11 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type: | |
| isinstance(callee_type, CallableType) | ||
| and callee_type.implicit): | ||
| return self.msg.untyped_function_call(callee_type, e) | ||
| ret_type = self.check_call_expr_with_callee_type(callee_type, e) | ||
| if not isinstance(e.callee, RefExpr): | ||
| fullname = None | ||
| else: | ||
| fullname = e.callee.fullname | ||
| ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname) | ||
| if isinstance(ret_type, UninhabitedType): | ||
| self.chk.binder.unreachable() | ||
| if not allow_none_return and isinstance(ret_type, NoneTyp): | ||
|
|
@@ -330,21 +337,42 @@ def try_infer_partial_type(self, e: CallExpr) -> None: | |
| list(full_item_types)) | ||
| del partial_types[var] | ||
|
|
||
| def apply_function_plugin(self, | ||
| arg_types: List[Type], | ||
| inferred_ret_type: Type, | ||
| arg_kinds: List[int], | ||
| formal_to_actual: List[List[int]], | ||
| args: List[Expression], | ||
| num_formals: int, | ||
| fullname: Optional[str]) -> Type: | ||
| """Use special case logic to infer the return type for of a particular named function. | ||
|
|
||
| Return the inferred return type. | ||
| """ | ||
| formal_arg_types = [None] * num_formals # type: List[Optional[Type]] | ||
| for formal, actuals in enumerate(formal_to_actual): | ||
| for actual in actuals: | ||
| formal_arg_types[formal] = arg_types[actual] | ||
| return self.function_plugins[fullname]( | ||
| formal_arg_types, inferred_ret_type, args, self.chk.named_generic_type) | ||
|
||
|
|
||
| def check_call_expr_with_callee_type(self, callee_type: Type, | ||
| e: CallExpr) -> Type: | ||
| e: CallExpr, callable_name: Optional[str]) -> Type: | ||
| """Type check call expression. | ||
|
|
||
| The given callee type overrides the type of the callee | ||
| expression. | ||
| """ | ||
| return self.check_call(callee_type, e.args, e.arg_kinds, e, | ||
| e.arg_names, callable_node=e.callee)[0] | ||
| e.arg_names, callable_node=e.callee, | ||
| callable_name=callable_name)[0] | ||
|
|
||
| def check_call(self, callee: Type, args: List[Expression], | ||
| arg_kinds: List[int], context: Context, | ||
| arg_names: List[str] = None, | ||
| callable_node: Expression = None, | ||
| arg_messages: MessageBuilder = None) -> Tuple[Type, Type]: | ||
| arg_messages: MessageBuilder = None, | ||
| callable_name: Optional[str] = None) -> Tuple[Type, Type]: | ||
| """Type check a call. | ||
|
|
||
| Also infer type arguments if the callee is a generic function. | ||
|
|
@@ -406,6 +434,11 @@ def check_call(self, callee: Type, args: List[Expression], | |
| if callable_node: | ||
| # Store the inferred callable type. | ||
| self.chk.store_type(callable_node, callee) | ||
| if callable_name in self.function_plugins: | ||
| ret_type = self.apply_function_plugin( | ||
| arg_types, callee.ret_type, arg_kinds, formal_to_actual, | ||
| args, len(callee.arg_types), callable_name) | ||
| callee = callee.copy_modified(ret_type=ret_type) | ||
| return callee.ret_type, callee | ||
| elif isinstance(callee, Overloaded): | ||
| # Type check arguments in empty context. They will be checked again | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| """Plugins that implement special type checking rules for individual functions. | ||
| The plugins infer better types for tricky functions such as "open". | ||
| """ | ||
|
|
||
| from typing import Tuple, Dict, Callable, List | ||
|
|
||
| from mypy.nodes import Expression, StrExpr | ||
| from mypy.types import Type, Instance, CallableType | ||
|
|
||
|
|
||
| PluginCallback = Callable[[List[Type], | ||
| Type, | ||
| List[Expression], | ||
| Callable[[str, List[Type]], Type]], | ||
| Type] | ||
|
|
||
|
|
||
| def get_function_plugin_callbacks(python_version: Tuple[int, int]) -> Dict[str, PluginCallback]: | ||
| """Return all available function plugins for a given Python version.""" | ||
| if python_version[0] == 3: | ||
| return { | ||
| 'builtins.open': open_callback, | ||
| 'contextlib.contextmanager': contextmanager_callback, | ||
| } | ||
| else: | ||
| return { | ||
| 'contextlib.contextmanager': contextmanager_callback, | ||
| } | ||
|
|
||
|
|
||
| def open_callback( | ||
| arg_types: List[Type], | ||
| inferred_return_type: Type, | ||
| args: List[Expression], | ||
| named_generic_type: Callable[[str, List[Type]], Type]) -> Type: | ||
| """Infer a better return type for 'open'. | ||
| Infer IO[str] or IO[bytes] as the return value if the mode argument is not | ||
| given or is a literal. | ||
| """ | ||
| mode = None | ||
| if arg_types[1] is None: | ||
| mode = 'r' | ||
| elif isinstance(args[1], StrExpr): | ||
| mode = args[1].value | ||
| if mode is not None: | ||
| assert isinstance(inferred_return_type, Instance) | ||
| if 'b' in mode: | ||
| arg = named_generic_type('builtins.bytes', []) | ||
| else: | ||
| arg = named_generic_type('builtins.str', []) | ||
| return Instance(inferred_return_type.type, [arg]) | ||
| return inferred_return_type | ||
|
|
||
|
|
||
| def contextmanager_callback( | ||
| arg_types: List[Type], | ||
| inferred_return_type: Type, | ||
| args: List[Expression], | ||
| named_generic_type: Callable[[str, List[Type]], Type]) -> Type: | ||
| """Infer a better return type for 'contextlib.contextmanager'.""" | ||
| arg_type = arg_types[0] | ||
|
||
| if isinstance(arg_type, CallableType) and isinstance(inferred_return_type, CallableType): | ||
| # The stub signature doesn't preserve information about arguments so | ||
| # add them back here. | ||
| return inferred_return_type.copy_modified( | ||
| arg_types=arg_type.arg_types, | ||
| arg_kinds=arg_type.arg_kinds, | ||
| arg_names=arg_type.arg_names) | ||
| return inferred_return_type | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -399,7 +399,20 @@ f.write('x') | |
| f.write(b'x') | ||
| f.foobar() | ||
| [out] | ||
| _program.py:4: error: IO[Any] has no attribute "foobar" | ||
| _program.py:3: error: Argument 1 to "write" of "IO" has incompatible type "bytes"; expected "str" | ||
| _program.py:4: error: IO[str] has no attribute "foobar" | ||
|
|
||
| [case testOpenReturnTypeInference] | ||
| reveal_type(open('x')) | ||
| reveal_type(open('x', 'r')) | ||
| reveal_type(open('x', 'rb')) | ||
| mode = 'rb' | ||
| reveal_type(open('x', mode)) | ||
| [out] | ||
| _program.py:1: error: Revealed type is 'typing.IO[builtins.str]' | ||
| _program.py:2: error: Revealed type is 'typing.IO[builtins.str]' | ||
| _program.py:3: error: Revealed type is 'typing.IO[builtins.bytes]' | ||
| _program.py:5: error: Revealed type is 'typing.IO[Any]' | ||
|
|
||
| [case testGenericPatterns] | ||
| from typing import Pattern | ||
|
|
@@ -1286,3 +1299,27 @@ a[1] = 2, 'y' | |
| a[:] = [('z', 3)] | ||
| [out] | ||
| _program.py:4: error: Incompatible types in assignment (expression has type "Tuple[int, str]", target has type "Tuple[str, int]") | ||
|
|
||
| [case testContextManager] | ||
| import contextlib | ||
| from contextlib import contextmanager | ||
| from typing import Iterator | ||
|
|
||
| @contextmanager | ||
| def f(x: int) -> Iterator[str]: | ||
| yield 'foo' | ||
|
|
||
| @contextlib.contextmanager | ||
| def g(*x: str) -> Iterator[int]: | ||
| yield 1 | ||
|
|
||
| reveal_type(f) | ||
| reveal_type(g) | ||
|
|
||
| with f('') as s: | ||
| reveal_type(s) | ||
| [out] | ||
| _program.py:13: error: Revealed type is 'def (x: builtins.int) -> contextlib.GeneratorContextManager[builtins.str*]' | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just realized that that class is misnamed in typeshed, it should be
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| _program.py:14: error: Revealed type is 'def (*x: builtins.str) -> contextlib.GeneratorContextManager[builtins.int*]' | ||
| _program.py:16: error: Argument 1 to "f" has incompatible type "str"; expected "int" | ||
| _program.py:17: error: Revealed type is 'builtins.str*' | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aren't there some edge cases where map_actuals_to_formals() returns overlapping mappings? (IIRC related to
*argsand worse.)