Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2219,8 +2219,12 @@ def visit_decorator(self, e: Decorator) -> None:
continue
dec = self.expr_checker.accept(d)
temp = self.temp_node(sig)
fullname = None
if isinstance(d, RefExpr):
fullname = d.fullname
sig, t2 = self.expr_checker.check_call(dec, [temp],
[nodes.ARG_POS], e)
[nodes.ARG_POS], e,
callable_name=fullname)
sig = cast(FunctionLike, sig)
sig = set_callable_name(sig, e.func)
e.var.type = sig
Expand Down
41 changes: 37 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from mypy.util import split_module_names
from mypy.typevars import fill_typevars
from mypy.visitor import ExpressionVisitor
from mypy.funcplugins import get_function_plugin_callbacks, PluginCallback

from mypy import experiments

Expand Down Expand Up @@ -103,6 +104,7 @@ class ExpressionChecker(ExpressionVisitor[Type]):
type_context = None # type: List[Optional[Type]]

strfrm_checker = None # type: StringFormatterChecker
function_plugins = None # type: Dict[str, PluginCallback]

def __init__(self,
chk: 'mypy.checker.TypeChecker',
Expand All @@ -112,6 +114,7 @@ def __init__(self,
self.msg = msg
self.type_context = [None]
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)
self.function_plugins = get_function_plugin_callbacks(self.chk.options.python_version)

def visit_name_expr(self, e: NameExpr) -> Type:
"""Type check a name expression.
Expand Down Expand Up @@ -198,7 +201,11 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
isinstance(callee_type, CallableType)
and callee_type.implicit):
return self.msg.untyped_function_call(callee_type, e)
ret_type = self.check_call_expr_with_callee_type(callee_type, e)
if not isinstance(e.callee, RefExpr):
fullname = None
else:
fullname = e.callee.fullname
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname)
if isinstance(ret_type, UninhabitedType):
self.chk.binder.unreachable()
if not allow_none_return and isinstance(ret_type, NoneTyp):
Expand Down Expand Up @@ -330,21 +337,42 @@ def try_infer_partial_type(self, e: CallExpr) -> None:
list(full_item_types))
del partial_types[var]

def apply_function_plugin(self,
arg_types: List[Type],
inferred_ret_type: Type,
arg_kinds: List[int],
formal_to_actual: List[List[int]],
args: List[Expression],
num_formals: int,
fullname: Optional[str]) -> Type:
"""Use special case logic to infer the return type for of a particular named function.

Return the inferred return type.
"""
formal_arg_types = [None] * num_formals # type: List[Optional[Type]]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_types[formal] = arg_types[actual]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't there some edge cases where map_actuals_to_formals() returns overlapping mappings? (IIRC related to *args and worse.)

return self.function_plugins[fullname](
formal_arg_types, inferred_ret_type, args, self.chk.named_generic_type)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the plugins have access to the Errors object too?


def check_call_expr_with_callee_type(self, callee_type: Type,
e: CallExpr) -> Type:
e: CallExpr, callable_name: Optional[str]) -> Type:
"""Type check call expression.

The given callee type overrides the type of the callee
expression.
"""
return self.check_call(callee_type, e.args, e.arg_kinds, e,
e.arg_names, callable_node=e.callee)[0]
e.arg_names, callable_node=e.callee,
callable_name=callable_name)[0]

def check_call(self, callee: Type, args: List[Expression],
arg_kinds: List[int], context: Context,
arg_names: List[str] = None,
callable_node: Expression = None,
arg_messages: MessageBuilder = None) -> Tuple[Type, Type]:
arg_messages: MessageBuilder = None,
callable_name: Optional[str] = None) -> Tuple[Type, Type]:
"""Type check a call.

Also infer type arguments if the callee is a generic function.
Expand Down Expand Up @@ -406,6 +434,11 @@ def check_call(self, callee: Type, args: List[Expression],
if callable_node:
# Store the inferred callable type.
self.chk.store_type(callable_node, callee)
if callable_name in self.function_plugins:
ret_type = self.apply_function_plugin(
arg_types, callee.ret_type, arg_kinds, formal_to_actual,
args, len(callee.arg_types), callable_name)
callee = callee.copy_modified(ret_type=ret_type)
return callee.ret_type, callee
elif isinstance(callee, Overloaded):
# Type check arguments in empty context. They will be checked again
Expand Down
71 changes: 71 additions & 0 deletions mypy/funcplugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Plugins that implement special type checking rules for individual functions.
The plugins infer better types for tricky functions such as "open".
"""

from typing import Tuple, Dict, Callable, List

from mypy.nodes import Expression, StrExpr
from mypy.types import Type, Instance, CallableType


PluginCallback = Callable[[List[Type],
Type,
List[Expression],
Callable[[str, List[Type]], Type]],
Type]


def get_function_plugin_callbacks(python_version: Tuple[int, int]) -> Dict[str, PluginCallback]:
"""Return all available function plugins for a given Python version."""
if python_version[0] == 3:
return {
'builtins.open': open_callback,
'contextlib.contextmanager': contextmanager_callback,
}
else:
return {
'contextlib.contextmanager': contextmanager_callback,
}


def open_callback(
arg_types: List[Type],
inferred_return_type: Type,
args: List[Expression],
named_generic_type: Callable[[str, List[Type]], Type]) -> Type:
"""Infer a better return type for 'open'.
Infer IO[str] or IO[bytes] as the return value if the mode argument is not
given or is a literal.
"""
mode = None
if arg_types[1] is None:
mode = 'r'
elif isinstance(args[1], StrExpr):
mode = args[1].value
if mode is not None:
assert isinstance(inferred_return_type, Instance)
if 'b' in mode:
arg = named_generic_type('builtins.bytes', [])
else:
arg = named_generic_type('builtins.str', [])
return Instance(inferred_return_type.type, [arg])
return inferred_return_type


def contextmanager_callback(
arg_types: List[Type],
inferred_return_type: Type,
args: List[Expression],
named_generic_type: Callable[[str, List[Type]], Type]) -> Type:
"""Infer a better return type for 'contextlib.contextmanager'."""
arg_type = arg_types[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would fail if there were no arguments, right? (Which would give some other error but might still get here?)

if isinstance(arg_type, CallableType) and isinstance(inferred_return_type, CallableType):
# The stub signature doesn't preserve information about arguments so
# add them back here.
return inferred_return_type.copy_modified(
arg_types=arg_type.arg_types,
arg_kinds=arg_type.arg_kinds,
arg_names=arg_type.arg_names)
return inferred_return_type
39 changes: 38 additions & 1 deletion test-data/unit/pythoneval.test
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,20 @@ f.write('x')
f.write(b'x')
f.foobar()
[out]
_program.py:4: error: IO[Any] has no attribute "foobar"
_program.py:3: error: Argument 1 to "write" of "IO" has incompatible type "bytes"; expected "str"
_program.py:4: error: IO[str] has no attribute "foobar"

[case testOpenReturnTypeInference]
reveal_type(open('x'))
reveal_type(open('x', 'r'))
reveal_type(open('x', 'rb'))
mode = 'rb'
reveal_type(open('x', mode))
[out]
_program.py:1: error: Revealed type is 'typing.IO[builtins.str]'
_program.py:2: error: Revealed type is 'typing.IO[builtins.str]'
_program.py:3: error: Revealed type is 'typing.IO[builtins.bytes]'
_program.py:5: error: Revealed type is 'typing.IO[Any]'

[case testGenericPatterns]
from typing import Pattern
Expand Down Expand Up @@ -1286,3 +1299,27 @@ a[1] = 2, 'y'
a[:] = [('z', 3)]
[out]
_program.py:4: error: Incompatible types in assignment (expression has type "Tuple[int, str]", target has type "Tuple[str, int]")

[case testContextManager]
import contextlib
from contextlib import contextmanager
from typing import Iterator

@contextmanager
def f(x: int) -> Iterator[str]:
yield 'foo'

@contextlib.contextmanager
def g(*x: str) -> Iterator[int]:
yield 1

reveal_type(f)
reveal_type(g)

with f('') as s:
reveal_type(s)
[out]
_program.py:13: error: Revealed type is 'def (x: builtins.int) -> contextlib.GeneratorContextManager[builtins.str*]'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realized that that class is misnamed in typeshed, it should be _GeneratorContextManager (to match what it's called at runtime). I also don't understand what its __call__ method is for (contextlib doesn't seem to have reference docs, and the source has few clues).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __call__ is so that @contextmanager-decorated functions can also be used as decorators themselves (executing the decorated function within the context). Nick Coghlan has said that he considers this feature a design mistake in contextlib.

_program.py:14: error: Revealed type is 'def (*x: builtins.str) -> contextlib.GeneratorContextManager[builtins.int*]'
_program.py:16: error: Argument 1 to "f" has incompatible type "str"; expected "int"
_program.py:17: error: Revealed type is 'builtins.str*'