Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions src/docstub/_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import sys
import time
from collections import Counter
from contextlib import contextmanager
from pathlib import Path

Expand All @@ -20,6 +21,7 @@
walk_source,
walk_source_and_targets,
)
from ._utils import ErrorReporter, GroupedErrorReporter
from ._version import __version__

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -143,10 +145,16 @@ def report_execution_time():
type=click.Path(exists=True, dir_okay=False),
help="Set configuration file explicitly.",
)
@click.option(
"--group-errors",
is_flag=True,
help="Group errors by type and content. "
"Will delay showing errors until all files have been processed.",
)
@click.option("-v", "--verbose", count=True, help="Log more details.")
@click.help_option("-h", "--help")
@report_execution_time()
def main(source_dir, out_dir, config_path, verbose):
def main(source_dir, out_dir, config_path, group_errors, verbose):
"""
Parameters
----------
Expand All @@ -155,25 +163,30 @@ def main(source_dir, out_dir, config_path, verbose):
config_path : Path
verbose : str
"""

# Setup -------------------------------------------------------------------

_setup_logging(verbose=verbose)

source_dir = Path(source_dir)
config = _load_configuration(config_path)
known_imports = _build_import_map(config, source_dir)

reporter = GroupedErrorReporter() if group_errors else ErrorReporter()
types_db = TypesDatabase(
source_pkgs=[source_dir.parent.resolve()], known_imports=known_imports
)
# and the stub transformer
stub_transformer = Py2StubTransformer(
types_db=types_db, replace_doctypes=config.replace_doctypes
types_db=types_db, replace_doctypes=config.replace_doctypes, reporter=reporter
)

if not out_dir:
out_dir = source_dir.parent / (source_dir.name + "-stubs")
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

# Stub generation ---------------------------------------------------------

for source_path, stub_path in walk_source_and_targets(source_dir, out_dir):
if source_path.suffix.lower() == ".pyi":
logger.debug("using existing stub file %s", source_path)
Expand All @@ -199,18 +212,25 @@ def main(source_dir, out_dir, config_path, verbose):
logger.info("wrote %s", stub_path)
fo.write(stub_content)

# Reporting --------------------------------------------------------------

if group_errors:
reporter.print_grouped()

# Report basic statistics
successful_queries = types_db.stats["successful_queries"]
click.secho(f"{successful_queries} matched annotations", fg="green")

grammar_error_count = stub_transformer.transformer.stats["grammar_errors"]
if grammar_error_count:
click.secho(f"{grammar_error_count} grammar violations", fg="red")
syntax_error_count = stub_transformer.transformer.stats["syntax_errors"]
if syntax_error_count:
click.secho(f"{syntax_error_count} syntax errors", fg="red")

unknown_doctypes = types_db.stats["unknown_doctypes"]
if unknown_doctypes:
click.secho(f"{len(unknown_doctypes)} unknown doctypes:", fg="red")
click.echo(" " + "\n ".join(set(unknown_doctypes)))
counter = Counter(unknown_doctypes)
for item, count in sorted(counter.items(), key=lambda x: x[1]):
click.echo(f" {item} (x{count})")

if unknown_doctypes or grammar_error_count:
if unknown_doctypes or syntax_error_count:
sys.exit(1)
2 changes: 1 addition & 1 deletion src/docstub/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def from_default(cls):
return config

def merge(self, other):
"""Merge contents with other and return a new Config instance.
"""Merge contents with other and return a copy_with Config instance.

Parameters
----------
Expand Down
64 changes: 38 additions & 26 deletions src/docstub/_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpydoc.docscrape as npds

from ._analysis import KnownImport, TypesDatabase
from ._utils import ContextFormatter, DocstubError, accumulate_qualname, escape_qualname
from ._utils import DocstubError, ErrorReporter, accumulate_qualname, escape_qualname

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,7 +80,7 @@ def many_as_tuple(cls, types):

@classmethod
def as_generator(cls, *, yield_types, receive_types=(), return_types=()):
"""Create new ``Generator`` type from yield, receive and return types.
"""Create copy_with ``Generator`` type from yield, receive and return types.

Parameters
----------
Expand Down Expand Up @@ -257,7 +257,7 @@ def __init__(self, *, types_db=None, replace_doctypes=None, **kwargs):

super().__init__(**kwargs)

self.stats = {"grammar_errors": 0}
self.stats = {"syntax_errors": 0}

def doctype_to_annotation(self, doctype):
"""Turn a type description in a docstring into a type annotation.
Expand Down Expand Up @@ -289,7 +289,7 @@ def doctype_to_annotation(self, doctype):
lark.exceptions.ParseError,
QualnameIsKeyword,
):
self.stats["grammar_errors"] += 1
self.stats["syntax_errors"] += 1
raise
finally:
self._collected_imports = None
Expand Down Expand Up @@ -443,7 +443,7 @@ class DocstringAnnotations:
----------
docstring : str
transformer : DoctypeTransformer
ctx : ~.ContextFormatter
reporter : ~.ErrorReporter

Examples
--------
Expand All @@ -457,32 +457,32 @@ class DocstringAnnotations:
>>> transformer = DoctypeTransformer()
>>> annotations = DocstringAnnotations(docstring, transformer=transformer)
>>> annotations.parameters.keys()
invalid syntax in doctype
Invalid syntax in docstring type annotation
some invalid syntax
^
<BLANKLINE>
unknown name in doctype: 'unknown.symbol'
Unknown name in doctype: 'unknown.symbol'
unknown.symbol
^^^^^^^^^^^^^^
<BLANKLINE>
dict_keys(['a', 'b', 'c'])
"""

def __init__(self, docstring, *, transformer, ctx=None):
def __init__(self, docstring, *, transformer, reporter=None):
"""
Parameters
----------
docstring : str
transformer : DoctypeTransformer
ctx : ~.ContextFormatter, optional
reporter : ~.ErrorReporter, optional
"""
self.docstring = docstring
self.np_docstring = npds.NumpyDocString(docstring)
self.transformer = transformer

if ctx is None:
ctx = ContextFormatter(line=0)
self.ctx: ContextFormatter = ctx
if reporter is None:
reporter = ErrorReporter(line=0)
self.reporter: ErrorReporter = reporter

def _doctype_to_annotation(self, doctype, ds_line=0):
"""Convert a type description to a Python-ready type.
Expand All @@ -501,7 +501,7 @@ def _doctype_to_annotation(self, doctype, ds_line=0):
The transformed type, ready to be inserted into a stub file, with
necessary imports attached.
"""
ctx = self.ctx.with_line(offset=ds_line)
reporter = self.reporter.copy_with(line_offset=ds_line)

try:
annotation, unknown_qualnames = self.transformer.doctype_to_annotation(
Expand All @@ -513,21 +513,23 @@ def _doctype_to_annotation(self, doctype, ds_line=0):
if hasattr(error, "get_context"):
details = error.get_context(doctype)
details = details.replace("^", click.style("^", fg="red", bold=True))
ctx.print_message("invalid syntax in doctype", details=details)
reporter.message(
"Invalid syntax in docstring type annotation", details=details
)
return FallbackAnnotation

except lark.visitors.VisitError as e:
tb = "\n".join(traceback.format_exception(e.orig_exc))
details = f"doctype: {doctype!r}\n\n{tb}"
ctx.print_message("unexpected error while parsing doctype", details=details)
reporter.message("unexpected error while parsing doctype", details=details)
return FallbackAnnotation

else:
for name, start_col, stop_col in unknown_qualnames:
width = stop_col - start_col
error_underline = click.style("^" * width, fg="red", bold=True)
details = f"{doctype}\n{' ' * start_col}{error_underline}\n"
ctx.print_message(f"unknown name in doctype: {name!r}", details=details)
reporter.message(f"Unknown name in doctype: {name!r}", details=details)
return annotation

@cached_property
Expand All @@ -553,7 +555,10 @@ def attributes(self):
break

if attribute.name in annotations:
logger.warning("duplicate parameter name %r, ignoring", attribute.name)
self.reporter.message(
"duplicate attribute name in docstring",
details=self.reporter.underline(attribute.name),
)
continue

annotation = self._doctype_to_annotation(attribute.type, ds_line=ds_line)
Expand All @@ -576,7 +581,10 @@ def parameters(self):

duplicates = param_section.keys() & other_section.keys()
for duplicate in duplicates:
logger.warning("duplicate parameter name %r, ignoring", duplicate)
self.reporter.message(
"duplicate attribute name in docstring",
details=self.reporter.underline(duplicate),
)

# Last takes priority
paramaters = other_section | param_section
Expand Down Expand Up @@ -653,20 +661,21 @@ def _handle_missing_whitespace(self, param):
param : numpydoc.docscrape.Parameter
"""
if ":" in param.name and param.type == "":
msg = (
"Possibly missing whitespace between parameter and colon in "
"docstring, make sure to include it so that the type is parsed "
"properly!"
msg = "Possibly missing whitespace between parameter and colon in docstring"
underline = "".join("^" if c == ":" else " " for c in param.name)
underline = click.style(underline, fg="red", bold=True)
hint = (
f"{param.name}\n{underline}"
f"\nInclude whitespace so that the type is parsed properly!"
)
hint = f"{param.name}"

ds_line = 0
for i, line in enumerate(self.docstring.split("\n")):
if param.name in line:
ds_line = i
break
ctx = self.ctx.with_line(offset=ds_line)
ctx.print_message(msg, details=hint)
reporter = self.reporter.copy_with(line_offset=ds_line)
reporter.message(msg, details=hint)

new_name, new_type = param.name.split(":", maxsplit=1)
param = npds.Parameter(name=new_name, type=new_type, desc=param.desc)
Expand All @@ -693,7 +702,10 @@ def _section_annotations(self, name):

if param.name in annotated_params:
# TODO make error
logger.warning("duplicate parameter name %r, ignoring", param.name)
self.reporter.message(
"duplicate parameter / attribute name in docstring",
details=self.reporter.underline(param.name),
)
continue

if param.type:
Expand Down
42 changes: 30 additions & 12 deletions src/docstub/_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ._analysis import KnownImport
from ._docstrings import DocstringAnnotations, DoctypeTransformer
from ._utils import ContextFormatter, module_name_from_path
from ._utils import ErrorReporter, module_name_from_path

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -350,18 +350,23 @@ def print_upper(x: Incomplete) -> None: ...
)
_Annotation_None: ClassVar[cst.Annotation] = cst.Annotation(cst.Name("None"))

def __init__(self, *, types_db=None, replace_doctypes=None):
def __init__(self, *, types_db=None, replace_doctypes=None, reporter=None):
"""
Parameters
----------
types_db : ~.TypesDatabase
replace_doctypes : dict[str, str]
reporter : ~.ErrorReporter
"""
if reporter is None:
reporter = ErrorReporter()

self.types_db = types_db
self.replace_doctypes = replace_doctypes
self.transformer = DoctypeTransformer(
types_db=types_db, replace_doctypes=replace_doctypes
)
self.reporter = reporter
# Relevant docstring for the current context
self._scope_stack = None # Entered module, class or function scopes
self._pytypes_stack = None # Collected pytypes for each stack
Expand Down Expand Up @@ -544,11 +549,17 @@ def leave_FunctionDef(self, original_node, updated_node):
position = self.get_metadata(
cst.metadata.PositionProvider, original_node
).start
ctx = ContextFormatter(path=self.current_source, line=position.line)
reporter = self.reporter.copy_with(
path=self.current_source, line=position.line
)
replaced = _inline_node_as_code(original_node.returns.annotation)
ctx.print_message(
short="replacing existing inline return annotation",
details=f"{replaced}\n{"^" * len(replaced)} -> {annotation_value}",
details = (
f"{replaced}\n"
f"{reporter.underline(replaced)} -> {annotation_value}"
)
reporter.message(
short="Replacing existing inline return annotation",
details=details,
)

annotation = cst.Annotation(cst.parse_expression(annotation_value))
Expand Down Expand Up @@ -735,13 +746,18 @@ def leave_AnnAssign(self, original_node, updated_node):
position = self.get_metadata(
cst.metadata.PositionProvider, original_node
).start
ctx = ContextFormatter(path=self.current_source, line=position.line)
reporter = self.reporter.copy_with(
path=self.current_source, line=position.line
)
replaced = cst.Module([]).code_for_node(
updated_node.annotation.annotation
)
ctx.print_message(
short="replacing existing inline annotation",
details=f"{replaced}\n{"^" * len(replaced)} -> {pytype.value}",
details = (
f"{replaced}\n{reporter.underline(replaced)} -> {pytype.value}"
)
reporter.message(
short="Replacing existing inline annotation",
details=details,
)

updated_node = updated_node.with_deep_changes(
Expand Down Expand Up @@ -901,12 +917,14 @@ def _annotations_from_node(self, node):
position = self.get_metadata(
cst.metadata.PositionProvider, docstring_node
).start
ctx = ContextFormatter(path=self.current_source, line=position.line)
reporter = self.reporter.copy_with(
path=self.current_source, line=position.line
)
try:
annotations = DocstringAnnotations(
docstring_node.evaluated_value,
transformer=self.transformer,
ctx=ctx,
reporter=reporter,
)
except (SystemExit, KeyboardInterrupt):
raise
Expand Down
Loading