diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index dfc7562..64e53e5 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -38,9 +38,11 @@ class ExampleClass: def method_in_nested_class(self, a1: complex) -> None: ... def __init__(self, a1: str, a2: float = ...) -> None: ... - def method(self, a1: float, a2: float | None) -> list[float]: ... + def method( + self, a1: float, a2: float = ..., a3: float | None = ... + ) -> list[float]: ... @staticmethod - def some_staticmethod(a1: float, a2: float | None = ...) -> dict[str, Any]: ... + def some_staticmethod(a1: float, a2: str = ...) -> dict[str, Any]: ... @property def some_property(self) -> str: ... @some_property.setter diff --git a/examples/example_pkg-stubs/_numpy.pyi b/examples/example_pkg-stubs/_numpy.pyi index 7d53ba1..9080523 100644 --- a/examples/example_pkg-stubs/_numpy.pyi +++ b/examples/example_pkg-stubs/_numpy.pyi @@ -6,7 +6,7 @@ def func_object_with_numpy_objects( a1: np.int8, a2: np.int16, a3: np.typing.DTypeLike, a4: np.typing.DTypeLike ) -> None: ... def func_ndarray( - a1: NDArray, a2: np.NDArray, a3: NDArray[float], a4: NDArray[np.uint8] = ... + a1: NDArray, a2: np.NDArray, a3: NDArray[float], a4: NDArray[np.uint8] | None = ... ) -> tuple[NDArray[np.uint8], NDArray[complex]]: ... def func_array_like( a1: ArrayLike, a2: ArrayLike, a3: ArrayLike[float], a4: ArrayLike[np.uint8] diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py index cdaa5a2..4c98b60 100644 --- a/examples/example_pkg/_basic.py +++ b/examples/example_pkg/_basic.py @@ -90,13 +90,14 @@ def method_in_nested_class(self, a1): def __init__(self, a1, a2=0): pass - def method(self, a1, a2): + def method(self, a1, a2=0, a3=None): """Dummy. Parameters ---------- a1 : float a2 : float, optional + a3 : float, optional Returns ------- @@ -110,7 +111,7 @@ def some_staticmethod(a1, a2="uno"): Parameters ---------- a1 : float - a2 : float, optional + a2 : str, optional Returns ------- diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 77058ce..7b85107 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -10,7 +10,7 @@ import libcst as cst -from ._utils import accumulate_qualname +from ._utils import accumulate_qualname, module_name_from_path logger = logging.getLogger(__name__) @@ -42,16 +42,21 @@ def _shared_leading_path(*paths): class KnownImport: """Import information associated with a single known type annotation. - Parameters + Attributes ---------- - import_name : - Dotted names after "import". - import_path : + import_path : str, optional Dotted names after "from". - import_alias : + import_name : str, optional + Dotted names after "import". + import_alias : str, optional Name (without ".") after "as". - builtin_name : + builtin_name : str, optional Names an object that's builtin and doesn't need an import. + + Examples + -------- + >>> KnownImport(import_path="numpy", import_name="uint8", import_alias="ui8") + """ import_name: str = None @@ -170,14 +175,6 @@ def __str__(self): return out -@dataclass(slots=True, frozen=True) -class InspectionContext: - """Currently inspected module and other information.""" - - file_path: Path - in_package_path: str - - def _is_type(value) -> bool: """Check if value is a type.""" # Checking for isinstance(..., type) isn't enough, some types such as @@ -262,45 +259,57 @@ def common_known_imports(): return known_imports -class KnownImportCollector(cst.CSTVisitor): +class TypeCollector(cst.CSTVisitor): @classmethod - def collect(cls, file, module_name): + def collect(cls, file): + """Collect importable type annotations in given file. + + Parameters + ---------- + file : Path + + Returns + ------- + collected : dict[str, KnownImport] + """ file = Path(file) with file.open("r") as fo: source = fo.read() tree = cst.parse_module(source) - collector = cls(module_name=module_name) + collector = cls(module_name=module_name_from_path(file)) tree.visit(collector) return collector.known_imports def __init__(self, *, module_name): + """Initialize type collector. + + Parameters + ---------- + module_name : str + """ self.module_name = module_name self._stack = [] self.known_imports = {} - def visit_ClassDef(self, node): + def visit_ClassDef(self, node: cst.ClassDef) -> bool: self._stack.append(node.name.value) class_name = ".".join(self._stack[:1]) qualname = f"{self.module_name}.{'.'.join(self._stack)}" - - known_import = KnownImport( - import_name=class_name, - import_path=self.module_name, - ) + known_import = KnownImport(import_path=self.module_name, import_name=class_name) self.known_imports[qualname] = known_import return True - def leave_ClassDef(self, original_node): + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: self._stack.pop() - def visit_FunctionDef(self, node): + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: self._stack.append(node.name.value) return True - def leave_FunctionDef(self, original_node): + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: self._stack.pop() @@ -395,7 +404,8 @@ def query(self, search_name): if known_import is None and self.current_source: # Try scope of current module - try_qualname = f"{self.current_source.import_path}.{search_name}" + module_name = module_name_from_path(self.current_source) + try_qualname = f"{module_name}.{search_name}" known_import = self.known_imports.get(try_qualname) if known_import: annotation_name = search_name diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 9dbe669..848ccd3 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -6,8 +6,8 @@ from ._analysis import ( KnownImport, - KnownImportCollector, StaticInspector, + TypeCollector, common_known_imports, ) from ._config import Config @@ -92,9 +92,7 @@ def main(source_dir, out_dir, config_path, verbose): known_imports = common_known_imports() for source_path in walk_source(source_dir): logger.info("collecting types in %s", source_path) - known_imports_in_source = KnownImportCollector.collect( - source_path, module_name=source_path.import_path - ) + known_imports_in_source = TypeCollector.collect(source_path) known_imports.update(known_imports_in_source) known_imports.update(KnownImport.many_from_config(config.known_imports)) diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 746313d..d7ba3d4 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -48,6 +48,9 @@ def __post_init__(self): object.__setattr__(self, "imports", frozenset(self.imports)) if "~" in self.value: raise ValueError(f"unexpected '~' in annotation value: {self.value}") + for import_ in self.imports: + if not isinstance(import_, KnownImport): + raise TypeError(f"unexpected type {type(import_)} in `imports`") def __str__(self) -> str: return self.value @@ -95,6 +98,22 @@ def as_yields_generator(cls, yield_types, receive_types=()): # TODO raise NotImplementedError() + def as_optional(self): + """Return optional version of this annotation by appending `| None`. + + Returns + ------- + optional : Annotation + + Examples + -------- + >>> Annotation(value="int").as_optional() + Annotation(value='int | None', imports=frozenset()) + """ + value = f"{self.value} | None" + optional = type(self)(value=value, imports=self.imports) + return optional + @staticmethod def _aggregate_annotations(*types): """Aggregate values and imports of given Annotations. @@ -118,14 +137,7 @@ def _aggregate_annotations(*types): GrammarErrorFallback = Annotation( value="Any", - imports=frozenset( - ( - KnownImport( - import_name="Any", - import_path="typing", - ), - ) - ), + imports=frozenset((KnownImport(import_path="typing", import_name="Any"),)), ) @@ -233,12 +245,8 @@ def types_or(self, tree): return out def optional(self, tree): - out = "None" - literal = [child for child in tree.children if child.type == "LITERAL"] - assert len(literal) <= 1 - if literal: - out = lark.Discard # Type should cover the default - return out + logger.debug("dropping optional / default info") + return lark.Discard def extra_info(self, tree): logger.debug("dropping extra info") diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index 5934990..d903ac1 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -3,53 +3,17 @@ import enum import logging from dataclasses import dataclass -from pathlib import Path import libcst as cst import libcst.matchers as cstm +from ._analysis import KnownImport from ._docstrings import DocstringAnnotations, DoctypeTransformer -from ._utils import ContextFormatter +from ._utils import ContextFormatter, module_name_from_path logger = logging.getLogger(__name__) -class PackageFile(Path): - """File in a Python package.""" - - def __init__(self, *args, package_root): - """ - Parameters - ---------- - args : tuple[Any, ...] - package_root : Path - """ - self.package_root = package_root - super().__init__(*args) - if self.is_dir(): - raise ValueError("mustn't be a directory") - if not self.is_relative_to(self.package_root): - raise ValueError("path must be relative to package_root") - - @property - def import_path(self): - """ - Returns - ------- - str - """ - relative_to_root = self.relative_to(self.package_root) - parts = relative_to_root.with_suffix("").parts - parts = (self.package_root.name, *parts) - if parts[-1] == "__init__": - parts = parts[:-1] - import_name = ".".join(parts) - return import_name - - def with_segments(self, *args): - return Path(*args) - - def _is_python_package(path): """ Parameters @@ -71,17 +35,11 @@ def walk_source(root_dir): ---------- root_dir : Path Root directory of a Python package. - target_dir : Path - Root directory in which a matching stub package will be created. Yields ------ - source_path : PackageFile + source_path : Path Either a Python file or a stub file that takes precedence. - - Notes - ----- - Files starting with "test_" are skipped entirely for now. """ queue = [root_dir] while queue: @@ -102,8 +60,7 @@ def walk_source(root_dir): if suffix == ".py" and path.with_suffix(".pyi").exists(): continue # Stub file already exists and takes precedence - python_file = PackageFile(path, package_root=root_dir) - yield python_file + yield path def walk_source_and_targets(root_dir, target_dir): @@ -118,9 +75,9 @@ def walk_source_and_targets(root_dir, target_dir): Returns ------- - source_path : PackageFile + source_path : Path Either a Python file or a stub file that takes precedence. - stub_path : PackageFile + stub_path : Path Target stub file. Notes @@ -129,7 +86,6 @@ def walk_source_and_targets(root_dir, target_dir): """ for source_path in walk_source(root_dir): stub_path = target_dir / source_path.with_suffix(".pyi").relative_to(root_dir) - stub_path = PackageFile(stub_path, package_root=target_dir) yield source_path, stub_path @@ -242,7 +198,7 @@ class Py2StubTransformer(cst.CSTTransformer): _Annotation_Any = cst.Annotation(cst.Name("Any")) _Annotation_None = cst.Annotation(cst.Name("None")) - def __init__(self, *, inspector, replace_doctypes): + def __init__(self, *, inspector=None, replace_doctypes=None): """ Parameters ---------- @@ -260,13 +216,25 @@ def __init__(self, *, inspector, replace_doctypes): self._required_imports = None # Collect imports for used types self._current_module = None + self._current_source = None # Use via property `current_source` + + @property + def current_source(self): + return self._current_source + + @current_source.setter + def current_source(self, value): + self._current_source = value + if self.inspector is not None: + self.inspector.current_source = value + def python_to_stub(self, source, *, module_path=None): """Convert Python source code to stub-file ready code. Parameters ---------- source : str - module_path : PackageFile, optional + module_path : Path, optional The location of the source that is transformed into a stub file. If given, used to enhance logging & error messages with more context information. @@ -279,8 +247,7 @@ def python_to_stub(self, source, *, module_path=None): self._scope_stack = [] self._pytypes_stack = [] self._required_imports = set() - if module_path: - self.inspector.current_source = module_path + self.current_source = module_path source_tree = cst.parse_module(source) source_tree = cst.metadata.MetadataWrapper(source_tree) @@ -292,7 +259,7 @@ def python_to_stub(self, source, *, module_path=None): self._scope_stack = None self._pytypes_stack = None self._required_imports = None - self.inspector.current_source = None + self.current_source = None def visit_ClassDef(self, node): """Collect pytypes from class docstring and add scope to stack. @@ -391,14 +358,21 @@ def leave_Param(self, original_node, updated_node): is_self_or_cls = ( scope.node.params.children[0] is original_node and scope.has_self_or_cls ) + defaults_to_none = cstm.matches(updated_node.default, cstm.Name(value="None")) + + if updated_node.default is not None: + node_changes["default"] = cst.Ellipsis() name = original_node.name.value pytypes = self._pytypes_stack[-1] if not pytypes and scope.is_class_init: pytypes = self._pytypes_stack[-2] + if pytypes: pytype = pytypes.parameters.get(name) if pytype: + if defaults_to_none: + pytype = pytype.as_optional() annotation = cst.Annotation(cst.parse_expression(pytype.value)) node_changes["annotation"] = annotation if pytype.imports: @@ -407,11 +381,8 @@ def leave_Param(self, original_node, updated_node): # Potentially use "Any" except for first param in (class)methods elif not is_self_or_cls and updated_node.annotation is None: node_changes["annotation"] = self._Annotation_Any - _, known_import = self.inspector.query("Any") - self._required_imports.add(known_import) - - if updated_node.default is not None: - node_changes["default"] = cst.Ellipsis() + import_ = KnownImport(import_path="typing", import_name="Any") + self._required_imports.add(import_) if node_changes: updated_node = updated_node.with_changes(**node_changes) @@ -496,10 +467,15 @@ def leave_Module(self, original_node, updated_node): ------- updated_node : cst.Module """ - current_module = self.inspector.current_source.import_path - required_imports = [ - imp for imp in self._required_imports if imp.import_path != current_module - ] + required_imports = self._required_imports.copy() + current_module = None + if self.current_source: + current_module = module_name_from_path(self.current_source) + required_imports = [ + imp + for imp in self._required_imports + if imp.import_path != current_module + ] import_nodes = self._parse_imports( required_imports, current_module=current_module ) @@ -617,9 +593,7 @@ def _annotations_from_node(self, node): cst.metadata.PositionProvider, docstring_node ).start - ctx = ContextFormatter( - path=Path(self.inspector.current_source), line=position.line - ) + ctx = ContextFormatter(path=self.current_source, line=position.line) try: annotations = DocstringAnnotations( docstring_node.evaluated_value, diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index 99dde09..eeb6779 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -1,6 +1,7 @@ import dataclasses import itertools import re +from functools import lru_cache from pathlib import Path from textwrap import indent @@ -63,6 +64,47 @@ def escape_qualname(name): return qualname +@lru_cache(maxsize=10) +def module_name_from_path(path): + """Find the full name of a module within its package from its file path. + + Parameters + ---------- + path : Path + + Returns + ------- + name : str + + Examples + -------- + >>> from pathlib import Path + >>> module_name_from_path(Path(__file__)) + 'docstub._utils' + >>> import docstub + >>> module_name_from_path(Path(docstub.__file__)) + 'docstub' + """ + if not path.is_file(): + raise FileNotFoundError(f"`path` is not an existing file: {path!r}") + + name_parts = [] + if path.name != "__init__.py": + name_parts.insert(0, path.stem) + + directory = path.parent + while True: + is_in_package = (directory / "__init__.py").is_file() + if is_in_package: + name_parts.insert(0, directory.name) + directory = directory.parent + else: + break + + name = ".".join(name_parts) + return name + + @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) class ContextFormatter: """Format messages in context of a location in a file. @@ -191,3 +233,8 @@ def print_message(self, short, *, details=None): """ msg = self.format_message(short, details=details, ansi_styles=True) click.echo(msg) + + def __post_init__(self): + if self.path is not None and not isinstance(self.path, Path): + msg = f"expected `path` to be of type `Path`, got {type(self.path)!r}" + raise TypeError(msg) diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index e626806..2f9a03f 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -16,7 +16,7 @@ optional : "optional" extra_info : /[^\r\n]+/ -sphinx_ref : ":" (NAME ":")? NAME ":`" qualname "`" +sphinx_ref : (":" (NAME ":")? NAME ":")? "`" qualname "`" container: qualname "[" types_or ("," types_or)* ("," PY_ELLIPSES)? "]" | qualname "of" type // TODO allow plural somehow, e.g. "list of int(s)"? @@ -49,7 +49,7 @@ leading_optional : "[" dim ("," dim)* ",]" insert_optional : "[," dim ("," dim)* "]" ?dim : NUMBER | PY_ELLIPSES - | NAME + | NAME ("=" NUMBER)? // ---------------------------------------------------------------------------- diff --git a/tests/test_docstrings.py b/tests/test_docstrings.py index bce50bb..1206ed7 100644 --- a/tests/test_docstrings.py +++ b/tests/test_docstrings.py @@ -73,11 +73,13 @@ def test_literals(self, doctype, expected): @pytest.mark.parametrize( ("doctype", "expected"), [ - ("int, optional", "int | None"), - # None isn't appended, since the type should cover the default - ("int, default 1", "int"), + ("int, optional", "int"), + ("int | None, optional", "int | None"), + ("int, default -1", "int"), ("int, default = 1", "int"), - ("int, default: 1", "int"), + ("int, default: 0", "int"), + ("float, default: 1.0", "float"), + ("{'a', 'b'}, default : 'a'", "Literal['a', 'b']"), ], ) @pytest.mark.parametrize("extra_info", [None, "int", ", extra, info"]) @@ -88,6 +90,20 @@ def test_optional_extra_info(self, doctype, expected, extra_info): annotation, _ = transformer.doctype_to_annotation(doctype) assert annotation.value == expected + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + ("`Generator`", "Generator"), + (":class:`Generator`", "Generator"), + (":py:class:`Generator`", "Generator"), + ("list[:py:class:`Generator`]", "list[Generator]"), + ], + ) + def test_sphinx_ref(self, doctype, expected): + transformer = DoctypeTransformer() + annotation, _ = transformer.doctype_to_annotation(doctype) + assert annotation.value == expected + # fmt: off @pytest.mark.parametrize( ("fmt", "expected_fmt"), @@ -152,7 +168,6 @@ def test_multiple_unknown_names(self): class Test_DocstringAnnotations: - def test_empty_docstring(self): docstring = dedent("""No sections in this docstring.""") transformer = DoctypeTransformer() @@ -165,7 +180,7 @@ def test_empty_docstring(self): [ ("bool", "bool"), ("str, extra information", "str"), - ("list of int, optional", "list[int] | None"), + ("list of int, optional", "list[int]"), ], ) def test_parameters(self, doctype, expected): diff --git a/tests/test_stubs.py b/tests/test_stubs.py index 45759ec..d14b70e 100644 --- a/tests/test_stubs.py +++ b/tests/test_stubs.py @@ -3,11 +3,10 @@ import libcst as cst import libcst.matchers as cstm -from docstub._stubs import _get_docstring_node +from docstub._stubs import Py2StubTransformer, _get_docstring_node class Test_get_docstring_node: - def test_func(self): docstring = dedent( ''' @@ -72,3 +71,27 @@ def foo(a, b=None): docstring_node = _get_docstring_node(func_def) assert docstring_node is None + + +class Test_Py2StubTransformer: + + def test_default_None(self): + # Appending `| None` if a doctype is marked as "optional" + # is only correct if the default is actually None + # Ref: https://github.com/scientific-python/docstub/issues/13 + source = dedent( + ''' + def foo(a=None, b=1): + """ + Parameters + ---------- + a : int, optional + b : int, optional + """ + ''' + ) + expected = "def foo(a: int | None = ..., b: int = ...) -> None: ..." + + transformer = Py2StubTransformer() + result = transformer.python_to_stub(source) + assert expected in result diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..0d2937c --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,25 @@ +from docstub._utils import module_name_from_path + + +class Test_module_name_from_path: + def test_basic(self, tmp_path): + # Package structure + structure = [ + "foo/", + "foo/__init__.py", + "foo/bar.py", + "foo/baz/", + "foo/baz/__init__.py", + "foo/baz/qux.py", + ] + for item in structure: + path = tmp_path / item + if item.endswith(".py"): + path.touch() + else: + path.mkdir() + + assert module_name_from_path(tmp_path / "foo/__init__.py") == "foo" + assert module_name_from_path(tmp_path / "foo/bar.py") == "foo.bar" + assert module_name_from_path(tmp_path / "foo/baz/__init__.py") == "foo.baz" + assert module_name_from_path(tmp_path / "foo/baz/qux.py") == "foo.baz.qux"