Skip to content

Commit 254f3d2

Browse files
committed
Pluggable system for producing types from docstrings
1 parent 0a9f88d commit 254f3d2

File tree

4 files changed

+124
-11
lines changed

4 files changed

+124
-11
lines changed

mypy/fastparse.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from mypy import defaults
2828
from mypy import experiments
2929
from mypy import messages
30+
from mypy import hooks
3031
from mypy.errors import Errors
3132

3233
try:
@@ -55,6 +56,8 @@
5556

5657
TYPE_COMMENT_SYNTAX_ERROR = 'syntax error in type comment'
5758
TYPE_COMMENT_AST_ERROR = 'invalid type comment or annotation'
59+
TYPE_COMMENT_DOCSTRING_ERROR = ('Arguments parsed from docstring are not '
60+
'present in function signature: {} not in {}')
5861

5962

6063
def parse(source: Union[str, bytes], fnam: str = None, errors: Errors = None,
@@ -109,6 +112,33 @@ def parse_type_comment(type_comment: str, line: int, errors: Errors) -> Optional
109112
return TypeConverter(errors, line=line).visit(typ.body)
110113

111114

115+
def parse_docstring(docstring: str, arg_names: List[str],
116+
line: int, errors: Errors) -> Optional[Tuple[List[Type], Type]]:
117+
"""Parse a docstring and return type representations.
118+
119+
Returns a 2-tuple: (list of arguments Types, and return Type).
120+
"""
121+
opts = hooks.options.get('docstring_parser', {})
122+
123+
def pop_and_convert(name: str) -> Optional[Type]:
124+
t = type_map.pop(name, None)
125+
if t is None:
126+
return AnyType()
127+
else:
128+
return parse_type_comment(t[0], line + t[1], errors)
129+
130+
if hooks.docstring_parser is not None:
131+
type_map = hooks.docstring_parser(docstring, opts, errors)
132+
if type_map:
133+
arg_types = [pop_and_convert(name) for name in arg_names]
134+
return_type = pop_and_convert('return')
135+
if type_map:
136+
errors.report(line, 0,
137+
TYPE_COMMENT_DOCSTRING_ERROR.format(type_map.keys(), arg_names))
138+
return arg_types, return_type
139+
return None
140+
141+
112142
def with_line(f: Callable[['ASTConverter', T], U]) -> Callable[['ASTConverter', T], U]:
113143
@wraps(f)
114144
def wrapper(self: 'ASTConverter', ast: T) -> U:
@@ -301,8 +331,9 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
301331
args = self.transform_args(n.args, n.lineno, no_type_check=no_type_check)
302332

303333
arg_kinds = [arg.kind for arg in args]
304-
arg_names = [arg.variable.name() for arg in args] # type: List[Optional[str]]
305-
arg_names = [None if argument_elide_name(name) else name for name in arg_names]
334+
real_names = [arg.variable.name() for arg in args] # type: List[str]
335+
arg_names = [None if argument_elide_name(name) else name
336+
for name in real_names] # type: List[Optional[str]]
306337
if special_function_elide_names(n.name):
307338
arg_names = [None] * len(arg_names)
308339
arg_types = None # type: List[Type]
@@ -342,6 +373,14 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
342373
else:
343374
arg_types = [a.type_annotation for a in args]
344375
return_type = TypeConverter(self.errors, line=n.lineno).visit(n.returns)
376+
# hooks
377+
if (not any(arg_types) and return_type is None and
378+
hooks.docstring_parser):
379+
doc = ast3.get_docstring(n, clean=False)
380+
if doc:
381+
types = parse_docstring(doc, real_names, n.lineno, self.errors)
382+
if types is not None:
383+
arg_types, return_type = types
345384

346385
for arg, arg_type in zip(args, arg_types):
347386
self.set_type_optional(arg_type, arg.initializer)

mypy/fastparse2.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@
4242
from mypy import experiments
4343
from mypy import messages
4444
from mypy.errors import Errors
45-
from mypy.fastparse import TypeConverter, parse_type_comment
45+
from mypy.fastparse import (TypeConverter, parse_type_comment,
46+
parse_docstring)
47+
from mypy import hooks
4648

4749
try:
4850
from typed_ast import ast27
@@ -290,8 +292,9 @@ def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement:
290292
args, decompose_stmts = self.transform_args(n.args, n.lineno)
291293

292294
arg_kinds = [arg.kind for arg in args]
293-
arg_names = [arg.variable.name() for arg in args] # type: List[Optional[str]]
294-
arg_names = [None if argument_elide_name(name) else name for name in arg_names]
295+
real_names = [arg.variable.name() for arg in args] # type: List[str]
296+
arg_names = [None if argument_elide_name(name) else name
297+
for name in real_names] # type: List[Optional[str]]
295298
if special_function_elide_names(n.name):
296299
arg_names = [None] * len(arg_names)
297300

@@ -326,6 +329,15 @@ def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement:
326329
else:
327330
arg_types = [a.type_annotation for a in args]
328331
return_type = converter.visit(None)
332+
# hooks
333+
if (not any(arg_types) and return_type is None and
334+
hooks.docstring_parser):
335+
doc = ast27.get_docstring(n, clean=False)
336+
if doc:
337+
types = parse_docstring(doc.decode('unicode_escape'),
338+
real_names, n.lineno, self.errors)
339+
if types is not None:
340+
arg_types, return_type = types
329341

330342
for arg, arg_type in zip(args, arg_types):
331343
self.set_type_optional(arg_type, arg.initializer)

mypy/hooks.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from typing import Dict, Optional, Callable, Tuple
2+
from mypy.errors import Errors
3+
4+
Options = Dict[str, str]
5+
options = {} # type: Dict[str, Options]
6+
7+
# The docstring_parser hook is called for each unannotated function that has a
8+
# docstring. The callable should accept three arguments:
9+
# - the docstring to be parsed
10+
# - a dictionary of options (parsed from the [docstring_parser] section of mypy
11+
# config file)
12+
# - an Errors object for reporting errors, warnings, and info.
13+
#
14+
# The function should return a map from argument name to 2-tuple. The latter should contain:
15+
# - a PEP484-compatible string. The function's return type, if specified, is stored
16+
# in the mapping with the special key 'return'. Other than 'return', each key of
17+
# the mapping must be one of the arguments of the documented function; otherwise,
18+
# an error will be raised.
19+
# - a line number offset, relative to the start of the docstring, used to
20+
# improve errors if the associated type string is invalid.
21+
#
22+
docstring_parser = None # type: Callable[[str, Options, Errors], Optional[Dict[str, Tuple[str, int]]]]

mypy/main.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
import re
88
import sys
99
import time
10+
from pydoc import locate
1011

11-
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple
12+
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Callable
1213

1314
from mypy import build
1415
from mypy import defaults
1516
from mypy import experiments
17+
from mypy import hooks
1618
from mypy import util
1719
from mypy.build import BuildSource, BuildResult, PYTHON_EXTENSIONS
1820
from mypy.errors import CompileError
@@ -548,6 +550,19 @@ def get_init_file(dir: str) -> Optional[str]:
548550
return None
549551

550552

553+
def load_hook(prefix: str, hook_name: str, hook_path: str) -> Optional[Callable]:
554+
# FIXME: no stubs for pydoc. should we write stubs or a simple replacement for locate?
555+
obj = locate(hook_path)
556+
if obj is None:
557+
print("%s: Could not find hook %s at %s" %
558+
(prefix, hook_name, hook_path), file=sys.stderr)
559+
if not callable(obj):
560+
print("%s: Hook %s at %s is not callable" %
561+
(prefix, hook_name, hook_path), file=sys.stderr)
562+
return None
563+
return obj
564+
565+
551566
# For most options, the type of the default value set in options.py is
552567
# sufficient, and we don't have to do anything here. This table
553568
# exists to specify types for values initialized to None or container
@@ -603,19 +618,33 @@ def parse_config_file(options: Options, filename: Optional[str]) -> None:
603618
else:
604619
section = parser['mypy']
605620
prefix = '%s: [%s]' % (file_read, 'mypy')
606-
updates, report_dirs = parse_section(prefix, options, section)
621+
updates, report_dirs, hook_funcs = parse_section(prefix, options, section)
607622
for k, v in updates.items():
608623
setattr(options, k, v)
624+
625+
# bind hook functions to hooks module
626+
for k, v in hook_funcs.items():
627+
hook_func = load_hook(prefix, k, v)
628+
if hook_func is not None:
629+
# FIXME: dynamically check loaded function annotations against those in `hooks`?
630+
setattr(hooks, k, hook_func)
631+
# look for an options section for this hook
632+
if k in parser:
633+
hooks.options[k] = dict(parser[k])
609634
options.report_dirs.update(report_dirs)
610635

611636
for name, section in parser.items():
612637
if name.startswith('mypy-'):
613638
prefix = '%s: [%s]' % (file_read, name)
614-
updates, report_dirs = parse_section(prefix, options, section)
639+
updates, report_dirs, hook_funcs = parse_section(prefix, options, section)
615640
if report_dirs:
616641
print("%s: Per-module sections should not specify reports (%s)" %
617642
(prefix, ', '.join(s + '_report' for s in sorted(report_dirs))),
618643
file=sys.stderr)
644+
if hook_funcs:
645+
print("%s: Per-module sections should not specify hooks (%s)" %
646+
(prefix, ', '.join(sorted(hook_funcs))),
647+
file=sys.stderr)
619648
if set(updates) - Options.PER_MODULE_OPTIONS:
620649
print("%s: Per-module sections should only specify per-module flags (%s)" %
621650
(prefix, ', '.join(sorted(set(updates) - Options.PER_MODULE_OPTIONS))),
@@ -632,16 +661,27 @@ def parse_config_file(options: Options, filename: Optional[str]) -> None:
632661

633662

634663
def parse_section(prefix: str, template: Options,
635-
section: Mapping[str, str]) -> Tuple[Dict[str, object], Dict[str, str]]:
664+
section: Mapping[str, str]) -> Tuple[Dict[str, object],
665+
Dict[str, str], Dict[str, str]]:
636666
"""Parse one section of a config file.
637667
638668
Returns a dict of option values encountered, and a dict of report directories.
639669
"""
640670
results = {} # type: Dict[str, object]
641671
report_dirs = {} # type: Dict[str, str]
672+
hook_funcs = {} # type: Dict[str, str]
642673
for key in section:
643674
key = key.replace('-', '_')
644-
if key in config_types:
675+
if key.startswith('hooks.'):
676+
dv = section.get(key)
677+
key = key[6:]
678+
if not hasattr(hooks, key):
679+
print("%s: Unrecognized hook: %s = %s" % (prefix, key, dv),
680+
file=sys.stderr)
681+
else:
682+
hook_funcs[key] = dv
683+
continue
684+
elif key in config_types:
645685
ct = config_types[key]
646686
else:
647687
dv = getattr(template, key, None)
@@ -685,7 +725,7 @@ def parse_section(prefix: str, template: Options,
685725
if 'follow_imports' not in results:
686726
results['follow_imports'] = 'error'
687727
results[key] = v
688-
return results, report_dirs
728+
return results, report_dirs, hook_funcs
689729

690730

691731
def fail(msg: str) -> None:

0 commit comments

Comments
 (0)