Skip to content

Support for overloaded functions in stubgenc generated by pybind11 #5975

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jan 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 50 additions & 50 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from types import ModuleType

from mypy.stubutil import (
is_c_module, write_header, infer_sig_from_docstring,
infer_prop_type_from_docstring
is_c_module, write_header, infer_sig_from_docstring, infer_prop_type_from_docstring,
ArgSig, infer_arg_sig_from_docstring, FunctionSig
)


Expand Down Expand Up @@ -123,51 +123,46 @@ def generate_c_function_stub(module: ModuleType,
) -> None:
ret_type = 'None' if name == '__init__' and class_name else 'Any'

if self_var:
self_arg = '%s, ' % self_var
else:
self_arg = ''
if (name in ('__new__', '__init__') and name not in sigs and class_name and
class_name in class_sigs):
sig = class_sigs[class_name]
inferred = [FunctionSig(name=name,
args=infer_arg_sig_from_docstring(class_sigs[class_name]),
ret_type=ret_type)] # type: Optional[List[FunctionSig]]
else:
docstr = getattr(obj, '__doc__', None)
inferred = infer_sig_from_docstring(docstr, name)
if inferred:
sig, ret_type = inferred
else:
if not inferred:
if class_name and name not in sigs:
sig = infer_method_sig(name)
inferred = [FunctionSig(name, args=infer_method_sig(name), ret_type=ret_type)]
else:
sig = sigs.get(name, '(*args, **kwargs)')
# strip away parenthesis
sig = sig[1:-1]
if sig:
if self_var:
# remove annotation on self from signature if present
groups = sig.split(',', 1)
if groups[0] == self_var or groups[0].startswith(self_var + ':'):
self_arg = ''
sig = '{},{}'.format(self_var, groups[1]) if len(groups) > 1 else self_var
else:
self_arg = self_arg.replace(', ', '')

if sig:
sig_types = []
# convert signature in form of "self: TestClass, arg0: str" to
# list [[self, TestClass], [arg0, str]]
for arg in sig.split(','):
arg_type = arg.split(':', 1)
if len(arg_type) == 1:
# there is no type provided in docstring
sig_types.append(arg_type[0].strip())
else:
arg_type_name = strip_or_import(arg_type[1].strip(), module, imports)
sig_types.append('%s: %s' % (arg_type[0].strip(), arg_type_name))
sig = ", ".join(sig_types)
inferred = [FunctionSig(name=name,
args=infer_arg_sig_from_docstring(
sigs.get(name, '(*args, **kwargs)')),
ret_type=ret_type)]

is_overloaded = len(inferred) > 1 if inferred else False
if is_overloaded:
imports.append('from typing import overload')
if inferred:
for signature in inferred:
sig = []
for arg in signature.args:
if arg.name == self_var or not arg.type:
# no type
sig.append(arg.name)
else:
# type info
sig.append('{}: {}'.format(arg.name, strip_or_import(arg.type,
module,
imports)))

ret_type = strip_or_import(ret_type, module, imports)
output.append('def %s(%s%s) -> %s: ...' % (name, self_arg, sig, ret_type))
if is_overloaded:
output.append('@overload')
output.append('def {function}({args}) -> {ret}: ...'.format(
function=name,
args=", ".join(sig),
ret=strip_or_import(signature.ret_type, module, imports)
))


def strip_or_import(typ: str, module: ModuleType, imports: List[str]) -> str:
Expand Down Expand Up @@ -307,29 +302,34 @@ def is_skipped_attribute(attr: str) -> bool:
'__weakref__') # For pickling


def infer_method_sig(name: str) -> str:
def infer_method_sig(name: str) -> List[ArgSig]:
if name.startswith('__') and name.endswith('__'):
name = name[2:-2]
if name in ('hash', 'iter', 'next', 'sizeof', 'copy', 'deepcopy', 'reduce', 'getinitargs',
'int', 'float', 'trunc', 'complex', 'bool'):
return '()'
return []
if name == 'getitem':
return '(index)'
return [ArgSig(name='index')]
if name == 'setitem':
return '(index, object)'
return [ArgSig(name='index'),
ArgSig(name='object')]
if name in ('delattr', 'getattr'):
return '(name)'
return [ArgSig(name='name')]
if name == 'setattr':
return '(name, value)'
return [ArgSig(name='name'),
ArgSig(name='value')]
if name == 'getstate':
return '()'
return []
if name == 'setstate':
return '(state)'
return [ArgSig(name='state')]
if name in ('eq', 'ne', 'lt', 'le', 'gt', 'ge',
'add', 'radd', 'sub', 'rsub', 'mul', 'rmul',
'mod', 'rmod', 'floordiv', 'rfloordiv', 'truediv', 'rtruediv',
'divmod', 'rdivmod', 'pow', 'rpow'):
return '(other)'
return [ArgSig(name='other')]
if name in ('neg', 'pos'):
return '()'
return '(*args, **kwargs)'
return []
return [
ArgSig(name='*args'),
ArgSig(name='**kwargs')
]
201 changes: 176 additions & 25 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,164 @@
import contextlib
import io
import re
import sys
import os
import tokenize

from typing import Optional, Tuple, Sequence, MutableSequence, List, MutableMapping, IO
from typing import (Optional, Tuple, Sequence, MutableSequence, List, MutableMapping, IO,
NamedTuple, Any)
from types import ModuleType

MYPY = False
if MYPY:
from typing_extensions import Final

# Type Alias for Signatures
Sig = Tuple[str, str]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still used somewhere? If yes, maybe we should switch to ArgSig also in the remaining places (if there are only few of them)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's still used in:

  • find_unique_signatures(sigs: Sequence[Sig]) -> List[Sig]:
  • parse_all_signatures(lines: Sequence[str]) -> Tuple[List[Sig], List[Sig]]



class ArgSig:
    """Description of a single argument of a function signature.

    Attributes:
        name: argument name as written in the signature (may carry a leading
            '*' or '**', e.g. '*args').
        type: textual type annotation, or None when no type is known.
        default: True when the argument has a default value.

    Instances compare by value. Because the class defines __eq__ without
    __hash__, instances are unhashable.
    """

    def __init__(self, name: str, type: Optional[str] = None, default: bool = False):
        self.name = name
        self.type = type
        self.default = default

    def __repr__(self) -> str:
        return "ArgSig(name={!r}, type={!r}, default={!r})".format(self.name, self.type,
                                                                   self.default)

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, ArgSig):
            return (self.name == other.name and self.type == other.type and
                    self.default == other.default)
        # Let Python try the reflected comparison of the other operand; the
        # final result for unrelated types is still False.
        return NotImplemented


# Signature of one function (or of one overload of a function):
#   name     -- name of the function
#   args     -- its arguments as ArgSig objects
#   ret_type -- return type as text (DocStringParser uses 'Any' when the
#               docstring does not give one)
FunctionSig = NamedTuple('FunctionSig', [
('name', str),
('args', List[ArgSig]),
('ret_type', str)
])


# States of the DocStringParser state machine. The parser keeps a stack of
# these values; the last element of the stack is the current state.

# Base state: outside of any signature.
STATE_INIT = 1 # type: Final
# The searched function name was just seen; a following '(' starts a signature.
STATE_FUNCTION_NAME = 2 # type: Final
# Inside the parentheses of the argument list.
STATE_ARGUMENT_LIST = 3 # type: Final
# After the ':' of an argument, collecting its type.
STATE_ARGUMENT_TYPE = 4 # type: Final
# After the '=' of an argument, collecting its default value.
STATE_ARGUMENT_DEFAULT = 5 # type: Final
# After '->', collecting the return type.
STATE_RETURN_VALUE = 6 # type: Final
# Inside a nested '[', '(' or '{' whose content is accumulated verbatim.
STATE_OPEN_BRACKET = 7 # type: Final


class DocStringParser:
    """State machine that extracts the signatures of one function from a docstring.

    Feed the tokens of the docstring to add_token() one at a time and then call
    get_signatures(). A signature has the form "<function_name>(<args>)",
    optionally followed by "-> <return type>"; it is recorded when the logical
    line containing it ends. A docstring that lists several overloads therefore
    yields several FunctionSig objects.
    """

    def __init__(self, function_name: str) -> None:
        # Name of the function whose signatures are looked for.
        self.function_name = function_name
        # Stack of STATE_* values; the last element is the current state.
        self.state = [STATE_INIT]
        # Concatenated text of the tokens seen since the last delimiter.
        self.accumulator = ""
        # Parts of the argument that is currently being parsed.
        self.arg_type = None  # type: Optional[str]
        self.arg_name = ""
        self.arg_default = None  # type: Optional[str]
        # Return type of the signature that is currently being parsed.
        self.ret_type = "Any"
        # True once "<function_name>(" was seen on the current logical line.
        self.found = False
        # Arguments collected for the signature that is currently being parsed.
        self.args = []  # type: List[ArgSig]
        # All completed signatures, in order of appearance.
        self.signatures = []  # type: List[FunctionSig]

    def add_token(self, token: tokenize.TokenInfo) -> None:
        """Advance the state machine by one token of the docstring."""
        if (token.type == tokenize.NAME and token.string == self.function_name and
                self.state[-1] == STATE_INIT):
            self.state.append(STATE_FUNCTION_NAME)

        elif (token.type == tokenize.OP and token.string == '(' and
                self.state[-1] == STATE_FUNCTION_NAME):
            # "<function_name>(": start of an argument list
            self.state.pop()
            self.accumulator = ""
            self.found = True
            self.state.append(STATE_ARGUMENT_LIST)

        elif self.state[-1] == STATE_FUNCTION_NAME:
            # reset state, function name not followed by '('
            self.state.pop()

        elif (token.type == tokenize.OP and token.string in ('[', '(', '{') and
                self.state[-1] != STATE_INIT):
            # nested bracket (e.g. "List[int]"): keep its text in the accumulator
            self.accumulator += token.string
            self.state.append(STATE_OPEN_BRACKET)

        elif (token.type == tokenize.OP and token.string in (']', ')', '}') and
                self.state[-1] == STATE_OPEN_BRACKET):
            self.accumulator += token.string
            self.state.pop()

        elif (token.type == tokenize.OP and token.string == ':' and
                self.state[-1] == STATE_ARGUMENT_LIST):
            # "name:" -> what follows is the argument type
            self.arg_name = self.accumulator
            self.accumulator = ""
            self.state.append(STATE_ARGUMENT_TYPE)

        elif (token.type == tokenize.OP and token.string == '=' and
                self.state[-1] in (STATE_ARGUMENT_LIST, STATE_ARGUMENT_TYPE)):
            # "name=" or "name: type=" -> what follows is the default value
            if self.state[-1] == STATE_ARGUMENT_TYPE:
                self.arg_type = self.accumulator
                self.state.pop()
            else:
                self.arg_name = self.accumulator
            self.accumulator = ""
            self.state.append(STATE_ARGUMENT_DEFAULT)

        elif (token.type == tokenize.OP and token.string in (',', ')') and
                self.state[-1] in (STATE_ARGUMENT_LIST, STATE_ARGUMENT_DEFAULT,
                                   STATE_ARGUMENT_TYPE)):
            # end of one argument, and for ')' also end of the argument list
            if self.state[-1] == STATE_ARGUMENT_DEFAULT:
                self.arg_default = self.accumulator
                self.state.pop()
            elif self.state[-1] == STATE_ARGUMENT_TYPE:
                self.arg_type = self.accumulator
                self.state.pop()
            elif self.state[-1] == STATE_ARGUMENT_LIST:
                self.arg_name = self.accumulator

            if token.string == ')':
                self.state.pop()

            # The name is empty when nothing precedes the delimiter, e.g. for
            # "func()" or after a trailing comma; that is not an argument.
            if self.arg_name:
                self.args.append(ArgSig(name=self.arg_name, type=self.arg_type,
                                        default=bool(self.arg_default)))
            self.arg_name = ""
            self.arg_type = None
            self.arg_default = None
            self.accumulator = ""

        elif token.type == tokenize.OP and token.string == '->' and self.state[-1] == STATE_INIT:
            self.accumulator = ""
            self.state.append(STATE_RETURN_VALUE)

        # ENDMARKER is necessary for python 3.4 and 3.5
        elif (token.type in (tokenize.NEWLINE, tokenize.ENDMARKER) and
                self.state[-1] in (STATE_INIT, STATE_RETURN_VALUE)):
            if self.state[-1] == STATE_RETURN_VALUE:
                self.ret_type = self.accumulator
                self.accumulator = ""
                self.state.pop()

            if self.found:
                self.signatures.append(FunctionSig(name=self.function_name, args=self.args,
                                                   ret_type=self.ret_type))
                self.found = False
            self.args = []
            self.ret_type = 'Any'
            # leave state as INIT
        else:
            self.accumulator += token.string

    def get_signatures(self) -> List[FunctionSig]:
        """Return the signatures found so far, catch-all signatures last."""
        def has_arg(name: str, signature: FunctionSig) -> bool:
            return any(x.name == name for x in signature.args)

        def args_kwargs(signature: FunctionSig) -> bool:
            return has_arg('*args', signature) and has_arg('**kwargs', signature)

        # Move functions with (*args, **kwargs) in their signature to last place.
        # False sorts before True and the sort is stable, so the relative order
        # of all other signatures is kept.
        return sorted(self.signatures, key=args_kwargs)


def parse_signature(sig: str) -> Optional[Tuple[str,
List[str],
List[str]]]:
Expand Down Expand Up @@ -106,32 +255,34 @@ def write_header(file: IO[str], module_name: Optional[str] = None,
'# NOTE: This dynamically typed stub was automatically generated by stubgen.\n\n')


def infer_sig_from_docstring(docstr: str, name: str) -> Optional[Tuple[str, str]]:
def infer_sig_from_docstring(docstr: str, name: str) -> Optional[List[FunctionSig]]:
"""Convert function signature to list of FunctionSig

Looks for function signatures of function in docstring. Returns empty list, when no signature
is found, one signature in typical case, multiple signatures, if docstring specifies multiple
signatures for overload functions.

Arguments:
* docstr: docstring
* name: name of function for which signatures are to be found
"""
if not docstr:
return None
docstr = docstr.lstrip()
# look for function signature, which is any string of the format
# <function_name>(<signature>) -> <return type>
# or perhaps without the return type

# in the signature, we allow the following characters:
# colon/equal: to match default values, like "a: int=1"
# comma/space/brackets: for type hints like "a: Tuple[int, float]"
# dot: for classes annotating using full path, like "a: foo.bar.baz"
# to capture return type,
sig_str = r'\([a-zA-Z0-9_=:, \[\]\.]*\)'
sig_match = r'%s(%s)' % (name, sig_str)
# first, try to capture return type; we just match until end of line
m = re.match(sig_match + ' -> ([a-zA-Z].*)$', docstr, re.MULTILINE)
if m:
# strip potential white spaces at the right of return type
return m.group(1), m.group(2).rstrip()

# try to not match return type
m = re.match(sig_match, docstr)
if m:
return m.group(1), 'Any'
return None

state = DocStringParser(name)
with contextlib.suppress(tokenize.TokenError):
for token in tokenize.tokenize(io.BytesIO(docstr.encode('utf-8')).readline):
state.add_token(token)
return state.get_signatures()


def infer_arg_sig_from_docstring(docstr: str) -> List[ArgSig]:
    """Convert an argument list like "(self: TestClass, arg0: str='ada')" to List[ArgSig].

    Arguments:
        * docstr: parenthesised argument list without a function name

    Returns the arguments of the first signature that is found, or an empty
    list when no signature is found.
    """
    # infer_sig_from_docstring() searches for "<name>(...)", so put a dummy
    # function name in front of the argument list.
    signatures = infer_sig_from_docstring("stub" + docstr, "stub")
    return signatures[0].args if signatures else []


def infer_prop_type_from_docstring(docstr: str) -> Optional[str]:
Expand Down
Loading