Skip to content

Commit 2296af8

Browse files
pkerichanggvanrossum
authored andcommitted
Improve stub generation so it works better with pybind11 wrappers (#5814)
1. Modify walk_packages() so when --recursive option is set, submodules in C extensions are also visited properly. 2. use inspect package to implement is_c_function(), is_c_method(), is_c_classmethod(), and is_c_type() so it is more accurate. The inspect methods are OR'ed with the old check to prevent breaking old code. 3. Improve infer_sig_from_docstring() so it supports Python-annotation style type hints (like a: int) in docstrings. 4. Misc. cleanup.
1 parent e09e727 commit 2296af8

File tree

4 files changed

+142
-35
lines changed

4 files changed

+142
-35
lines changed

mypy/stubgen.py

+27-10
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import os
4242
import os.path
4343
import pkgutil
44+
import inspect
4445
import subprocess
4546
import sys
4647
import textwrap
@@ -211,7 +212,6 @@ def generate_stub(path: str,
211212
pyversion: Tuple[int, int] = defaults.PYTHON3_VERSION,
212213
include_private: bool = False
213214
) -> None:
214-
215215
with open(path, 'rb') as f:
216216
data = f.read()
217217
source = mypy.util.decode_python_encoding(data, pyversion)
@@ -254,7 +254,7 @@ def __init__(self, stubgen: 'StubGenerator') -> None:
254254
super().__init__()
255255
self.stubgen = stubgen
256256

257-
def visit_unbound_type(self, t: UnboundType)-> str:
257+
def visit_unbound_type(self, t: UnboundType) -> str:
258258
s = t.name
259259
base = s.split('.')[0]
260260
self.stubgen.import_tracker.require_name(base)
@@ -593,7 +593,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
593593
if init:
594594
found = True
595595
if not sep and not self._indent and \
596-
self._state not in (EMPTY, VAR):
596+
self._state not in (EMPTY, VAR):
597597
init = '\n' + init
598598
sep = True
599599
self.add(init)
@@ -795,7 +795,7 @@ def get_str_type_of_node(self, rvalue: Expression,
795795
if isinstance(rvalue, NameExpr) and rvalue.name in ('True', 'False'):
796796
return 'bool'
797797
if can_infer_optional and \
798-
isinstance(rvalue, NameExpr) and rvalue.name == 'None':
798+
isinstance(rvalue, NameExpr) and rvalue.name == 'None':
799799
self.add_typing_import('Optional')
800800
self.add_typing_import('Any')
801801
return 'Optional[Any]'
@@ -850,7 +850,6 @@ def visit_return_stmt(self, o: ReturnStmt) -> None:
850850

851851

852852
def has_return_statement(fdef: FuncBase) -> bool:
853-
854853
seeker = ReturnSeeker()
855854
fdef.accept(seeker)
856855
return seeker.found
@@ -866,17 +865,35 @@ def get_qualified_name(o: Expression) -> str:
866865

867866

868867
def walk_packages(packages: List[str]) -> Iterator[str]:
868+
"""Iterates through all packages and sub-packages in the given list.
869+
870+
Python packages have a __path__ attribute defined, which pkgutil uses to determine
871+
the package hierarchy. However, packages in C extensions do not have this attribute,
872+
so we have to roll out our own.
873+
"""
869874
for package_name in packages:
870875
package = importlib.import_module(package_name)
871876
yield package.__name__
877+
# get the path of the object (needed by pkgutil)
872878
path = getattr(package, '__path__', None)
873879
if path is None:
880+
# object has no path; this means it's either a module inside a package
881+
# (and thus no sub-packages), or it could be a C extension package.
882+
if is_c_module(package):
883+
# This is a C extension module, now get the list of all sub-packages
884+
# using the inspect module
885+
subpackages = [package.__name__ + "." + name
886+
for name, val in inspect.getmembers(package)
887+
if inspect.ismodule(val)]
888+
# recursively iterate through the subpackages
889+
for submodule in walk_packages(subpackages):
890+
yield submodule
874891
# It's a module inside a package. There's nothing else to walk/yield.
875-
continue
876-
for importer, qualified_name, ispkg in pkgutil.walk_packages(path,
877-
prefix=package.__name__ + ".",
878-
onerror=lambda r: None):
879-
yield qualified_name
892+
else:
893+
all_packages = pkgutil.walk_packages(path, prefix=package.__name__ + ".",
894+
onerror=lambda r: None)
895+
for importer, qualified_name, ispkg in all_packages:
896+
yield qualified_name
880897

881898

882899
def main() -> None:

mypy/stubgenc.py

+57-18
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44
"""
55

66
import importlib
7+
import inspect
78
import os.path
89
import re
9-
from typing import List, Dict, Tuple, Optional, Mapping, Any
10+
from typing import List, Dict, Tuple, Optional, Mapping, Any, Set
1011
from types import ModuleType
1112

12-
from mypy.stubutil import is_c_module, write_header, infer_sig_from_docstring
13+
from mypy.stubutil import (
14+
is_c_module, write_header, infer_sig_from_docstring,
15+
infer_prop_type_from_docstring
16+
)
1317

1418

1519
def generate_stub_for_c_module(module_name: str,
@@ -41,7 +45,7 @@ def generate_stub_for_c_module(module_name: str,
4145
for name, obj in items:
4246
if name.startswith('__') and name.endswith('__'):
4347
continue
44-
if name not in done:
48+
if name not in done and not inspect.ismodule(obj):
4549
type_str = type(obj).__name__
4650
if type_str not in ('int', 'str', 'bytes', 'float', 'bool'):
4751
type_str = 'Any'
@@ -67,7 +71,7 @@ def generate_stub_for_c_module(module_name: str,
6771

6872
def add_typing_import(output: List[str]) -> List[str]:
6973
names = []
70-
for name in ['Any']:
74+
for name in ['Any', 'Union', 'Tuple', 'Optional', 'List', 'Dict']:
7175
if any(re.search(r'\b%s\b' % name, line) for line in output):
7276
names.append(name)
7377
if names:
@@ -77,22 +81,30 @@ def add_typing_import(output: List[str]) -> List[str]:
7781

7882

7983
def is_c_function(obj: object) -> bool:
80-
return type(obj) is type(ord)
84+
return inspect.isbuiltin(obj) or type(obj) is type(ord)
8185

8286

8387
def is_c_method(obj: object) -> bool:
84-
return type(obj) in (type(str.index),
85-
type(str.__add__),
86-
type(str.__new__))
88+
return inspect.ismethoddescriptor(obj) or type(obj) in (type(str.index),
89+
type(str.__add__),
90+
type(str.__new__))
8791

8892

8993
def is_c_classmethod(obj: object) -> bool:
90-
type_str = type(obj).__name__
91-
return type_str == 'classmethod_descriptor'
94+
return inspect.isbuiltin(obj) or type(obj).__name__ in ('classmethod',
95+
'classmethod_descriptor')
96+
97+
98+
def is_c_property(obj: object) -> bool:
99+
return inspect.isdatadescriptor(obj) and hasattr(obj, 'fget')
100+
101+
102+
def is_c_property_readonly(prop: object) -> bool:
103+
return getattr(prop, 'fset') is None
92104

93105

94106
def is_c_type(obj: object) -> bool:
95-
return type(obj) is type(int)
107+
return inspect.isclass(obj) or type(obj) is type(int)
96108

97109

98110
def generate_c_function_stub(module: ModuleType,
@@ -104,6 +116,8 @@ def generate_c_function_stub(module: ModuleType,
104116
class_name: Optional[str] = None,
105117
class_sigs: Dict[str, str] = {},
106118
) -> None:
119+
ret_type = 'Any'
120+
107121
if self_var:
108122
self_arg = '%s, ' % self_var
109123
else:
@@ -115,19 +129,37 @@ def generate_c_function_stub(module: ModuleType,
115129
docstr = getattr(obj, '__doc__', None)
116130
inferred = infer_sig_from_docstring(docstr, name)
117131
if inferred:
118-
sig = inferred
132+
sig, ret_type = inferred
119133
else:
120134
if class_name and name not in sigs:
121135
sig = infer_method_sig(name)
122136
else:
123137
sig = sigs.get(name, '(*args, **kwargs)')
138+
# strip away parenthesis
124139
sig = sig[1:-1]
125140
if sig:
126-
if sig.split(',', 1)[0] == self_var:
127-
self_arg = ''
141+
if self_var:
142+
# remove annotation on self from signature if present
143+
groups = sig.split(',', 1)
144+
if groups[0] == self_var or groups[0].startswith(self_var + ':'):
145+
self_arg = ''
146+
sig = '{},{}'.format(self_var, groups[1]) if len(groups) > 1 else self_var
128147
else:
129148
self_arg = self_arg.replace(', ', '')
130-
output.append('def %s(%s%s): ...' % (name, self_arg, sig))
149+
output.append('def %s(%s%s) -> %s: ...' % (name, self_arg, sig, ret_type))
150+
151+
152+
def generate_c_property_stub(name: str, obj: object, output: List[str], readonly: bool) -> None:
153+
docstr = getattr(obj, '__doc__', None)
154+
inferred = infer_prop_type_from_docstring(docstr)
155+
if not inferred:
156+
inferred = 'Any'
157+
158+
output.append('@property')
159+
output.append('def {}(self) -> {}: ...'.format(name, inferred))
160+
if not readonly:
161+
output.append('@{}.setter'.format(name))
162+
output.append('def {}(self, val: {}) -> None: ...'.format(name, inferred))
131163

132164

133165
def generate_c_type_stub(module: ModuleType,
@@ -141,8 +173,9 @@ def generate_c_type_stub(module: ModuleType,
141173
# (it could be a mappingproxy!), which makes mypyc mad, so obfuscate it.
142174
obj_dict = getattr(obj, '__dict__') # type: Mapping[str, Any]
143175
items = sorted(obj_dict.items(), key=lambda x: method_name_sort_key(x[0]))
144-
methods = []
145-
done = set()
176+
methods = [] # type: List[str]
177+
properties = [] # type: List[str]
178+
done = set() # type: Set[str]
146179
for attr, value in items:
147180
if is_c_method(value) or is_c_classmethod(value):
148181
done.add(attr)
@@ -162,6 +195,10 @@ def generate_c_type_stub(module: ModuleType,
162195
attr = '__init__'
163196
generate_c_function_stub(module, attr, value, methods, self_var, sigs=sigs,
164197
class_name=class_name, class_sigs=class_sigs)
198+
elif is_c_property(value):
199+
done.add(attr)
200+
generate_c_property_stub(attr, value, properties, is_c_property_readonly(value))
201+
165202
variables = []
166203
for attr, value in items:
167204
if is_skipped_attribute(attr):
@@ -183,14 +220,16 @@ def generate_c_type_stub(module: ModuleType,
183220
bases_str = '(%s)' % ', '.join(base.__name__ for base in bases)
184221
else:
185222
bases_str = ''
186-
if not methods and not variables:
223+
if not methods and not variables and not properties:
187224
output.append('class %s%s: ...' % (class_name, bases_str))
188225
else:
189226
output.append('class %s%s:' % (class_name, bases_str))
190227
for variable in variables:
191228
output.append(' %s' % variable)
192229
for method in methods:
193230
output.append(' %s' % method)
231+
for prop in properties:
232+
output.append(' %s' % prop)
194233

195234

196235
def method_name_sort_key(name: str) -> Tuple[int, str]:

mypy/stubutil.py

+36-4
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,44 @@ def write_header(file: IO[str], module_name: Optional[str] = None,
106106
'# NOTE: This dynamically typed stub was automatically generated by stubgen.\n\n')
107107

108108

109-
def infer_sig_from_docstring(docstr: str, name: str) -> Optional[str]:
109+
def infer_sig_from_docstring(docstr: str, name: str) -> Optional[Tuple[str, str]]:
110110
if not docstr:
111111
return None
112112
docstr = docstr.lstrip()
113-
m = re.match(r'%s(\([a-zA-Z0-9_=, ]*\))' % name, docstr)
113+
# look for function signature, which is any string of the format
114+
# <function_name>(<signature>) -> <return type>
115+
# or perhaps without the return type
116+
117+
# in the signature, we allow the following characters:
118+
# colon/equal: to match default values, like "a: int=1"
119+
# comma/space/brackets: for type hints like "a: Tuple[int, float]"
120+
# dot: for classes annotating using full path, like "a: foo.bar.baz"
121+
# to capture return type,
122+
sig_str = r'\([a-zA-Z0-9_=:, \[\]\.]*\)'
123+
sig_match = r'%s(%s)' % (name, sig_str)
124+
# first, try to capture return type; we just match until end of line
125+
m = re.match(sig_match + ' -> ([a-zA-Z].*)$', docstr, re.MULTILINE)
114126
if m:
115-
return m.group(1)
116-
else:
127+
# strip potential white spaces at the right of return type
128+
return m.group(1), m.group(2).rstrip()
129+
130+
# try to not match return type
131+
m = re.match(sig_match, docstr)
132+
if m:
133+
return m.group(1), 'Any'
134+
return None
135+
136+
137+
def infer_prop_type_from_docstring(docstr: str) -> Optional[str]:
138+
if not docstr:
117139
return None
140+
141+
# check for Google/Numpy style docstring type annotation
142+
# the docstring has the format "<type>: <descriptions>"
143+
# in the type string, we allow the following characters
144+
# dot: because something classes are annotated using full path,
145+
# brackets: to allow type hints like List[int]
146+
# comma/space: things like Tuple[int, int]
147+
test_str = r'^([a-zA-Z0-9_, \.\[\]]*): '
148+
m = re.match(test_str, docstr)
149+
return m.group(1) if m else None

mypy/test/teststubgen.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from mypy.stubgenc import generate_c_type_stub, infer_method_sig
1818
from mypy.stubutil import (
1919
parse_signature, parse_all_signatures, build_signature, find_unique_signatures,
20-
infer_sig_from_docstring
20+
infer_sig_from_docstring, infer_prop_type_from_docstring
2121
)
2222

2323

@@ -103,12 +103,31 @@ def test_find_unique_signatures(self) -> None:
103103
('func3', '(arg, arg2)')])
104104

105105
def test_infer_sig_from_docstring(self) -> None:
106-
assert_equal(infer_sig_from_docstring('\nfunc(x) - y', 'func'), '(x)')
107-
assert_equal(infer_sig_from_docstring('\nfunc(x, Y_a=None)', 'func'), '(x, Y_a=None)')
106+
assert_equal(infer_sig_from_docstring('\nfunc(x) - y', 'func'), ('(x)', 'Any'))
107+
assert_equal(infer_sig_from_docstring('\nfunc(x, Y_a=None)', 'func'),
108+
('(x, Y_a=None)', 'Any'))
108109
assert_equal(infer_sig_from_docstring('\nafunc(x) - y', 'func'), None)
109110
assert_equal(infer_sig_from_docstring('\nfunc(x, y', 'func'), None)
110111
assert_equal(infer_sig_from_docstring('\nfunc(x=z(y))', 'func'), None)
111112
assert_equal(infer_sig_from_docstring('\nfunc x', 'func'), None)
113+
# try to infer signature from type annotation
114+
assert_equal(infer_sig_from_docstring('\nfunc(x: int)', 'func'), ('(x: int)', 'Any'))
115+
assert_equal(infer_sig_from_docstring('\nfunc(x: int=3)', 'func'), ('(x: int=3)', 'Any'))
116+
assert_equal(infer_sig_from_docstring('\nfunc(x: int=3) -> int', 'func'),
117+
('(x: int=3)', 'int'))
118+
assert_equal(infer_sig_from_docstring('\nfunc(x: int=3) -> int \n', 'func'),
119+
('(x: int=3)', 'int'))
120+
assert_equal(infer_sig_from_docstring('\nfunc(x: Tuple[int, str]) -> str', 'func'),
121+
('(x: Tuple[int, str])', 'str'))
122+
assert_equal(infer_sig_from_docstring('\nfunc(x: foo.bar)', 'func'),
123+
('(x: foo.bar)', 'Any'))
124+
125+
def infer_prop_type_from_docstring(self) -> None:
126+
assert_equal(infer_prop_type_from_docstring('str: A string.'), 'str')
127+
assert_equal(infer_prop_type_from_docstring('Optional[int]: An int.'), 'Optional[int]')
128+
assert_equal(infer_prop_type_from_docstring('Tuple[int, int]: A tuple.'),
129+
'Tuple[int, int]')
130+
assert_equal(infer_prop_type_from_docstring('\nstr: A string.'), None)
112131

113132

114133
class StubgenPythonSuite(DataSuite):

0 commit comments

Comments
 (0)