Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/example_pkg-stubs/_basic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ class ExampleClass:
def method_in_nested_class(self, a1: complex) -> None: ...

def __init__(self, a1: str, a2: float = ...) -> None: ...
def method(self, a1: float, a2: float | None) -> list[float]: ...
def method(
self, a1: float, a2: float = ..., a3: float | None = ...
) -> list[float]: ...
@staticmethod
def some_staticmethod(a1: float, a2: float | None = ...) -> dict[str, Any]: ...
def some_staticmethod(a1: float, a2: str = ...) -> dict[str, Any]: ...
@property
def some_property(self) -> str: ...
@some_property.setter
Expand Down
2 changes: 1 addition & 1 deletion examples/example_pkg-stubs/_numpy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def func_object_with_numpy_objects(
a1: np.int8, a2: np.int16, a3: np.typing.DTypeLike, a4: np.typing.DTypeLike
) -> None: ...
def func_ndarray(
a1: NDArray, a2: np.NDArray, a3: NDArray[float], a4: NDArray[np.uint8] = ...
a1: NDArray, a2: np.NDArray, a3: NDArray[float], a4: NDArray[np.uint8] | None = ...
) -> tuple[NDArray[np.uint8], NDArray[complex]]: ...
def func_array_like(
a1: ArrayLike, a2: ArrayLike, a3: ArrayLike[float], a4: ArrayLike[np.uint8]
Expand Down
5 changes: 3 additions & 2 deletions examples/example_pkg/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,14 @@ def method_in_nested_class(self, a1):
def __init__(self, a1, a2=0):
pass

def method(self, a1, a2):
def method(self, a1, a2=0, a3=None):
"""Dummy.

Parameters
----------
a1 : float
a2 : float, optional
a3 : float, optional

Returns
-------
Expand All @@ -110,7 +111,7 @@ def some_staticmethod(a1, a2="uno"):
Parameters
----------
a1 : float
a2 : float, optional
a2 : str, optional

Returns
-------
Expand Down
66 changes: 38 additions & 28 deletions src/docstub/_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import libcst as cst

from ._utils import accumulate_qualname
from ._utils import accumulate_qualname, module_name_from_path

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,16 +42,21 @@ def _shared_leading_path(*paths):
class KnownImport:
"""Import information associated with a single known type annotation.

Parameters
Attributes
----------
import_name :
Dotted names after "import".
import_path :
import_path : str, optional
Dotted names after "from".
import_alias :
import_name : str, optional
Dotted names after "import".
import_alias : str, optional
Name (without ".") after "as".
builtin_name :
builtin_name : str, optional
Names an object that's builtin and doesn't need an import.

Examples
--------
>>> KnownImport(import_path="numpy", import_name="uint8", import_alias="ui8")
<KnownImport 'from numpy import uint8 as ui8'>
"""

import_name: str = None
Expand Down Expand Up @@ -170,14 +175,6 @@ def __str__(self):
return out


@dataclass(slots=True, frozen=True)
class InspectionContext:
"""Currently inspected module and other information."""

file_path: Path
in_package_path: str


def _is_type(value) -> bool:
"""Check if value is a type."""
# Checking for isinstance(..., type) isn't enough, some types such as
Expand Down Expand Up @@ -262,45 +259,57 @@ def common_known_imports():
return known_imports


class KnownImportCollector(cst.CSTVisitor):
class TypeCollector(cst.CSTVisitor):
@classmethod
def collect(cls, file, module_name):
def collect(cls, file):
"""Collect importable type annotations in given file.

Parameters
----------
file : Path

Returns
-------
collected : dict[str, KnownImport]
"""
file = Path(file)
with file.open("r") as fo:
source = fo.read()

tree = cst.parse_module(source)
collector = cls(module_name=module_name)
collector = cls(module_name=module_name_from_path(file))
tree.visit(collector)
return collector.known_imports

def __init__(self, *, module_name):
"""Initialize type collector.

Parameters
----------
module_name : str
"""
self.module_name = module_name
self._stack = []
self.known_imports = {}

def visit_ClassDef(self, node):
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
self._stack.append(node.name.value)

class_name = ".".join(self._stack[:1])
qualname = f"{self.module_name}.{'.'.join(self._stack)}"

known_import = KnownImport(
import_name=class_name,
import_path=self.module_name,
)
known_import = KnownImport(import_path=self.module_name, import_name=class_name)
self.known_imports[qualname] = known_import

return True

def leave_ClassDef(self, original_node):
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self._stack.pop()

def visit_FunctionDef(self, node):
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
self._stack.append(node.name.value)
return True

def leave_FunctionDef(self, original_node):
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self._stack.pop()


Expand Down Expand Up @@ -395,7 +404,8 @@ def query(self, search_name):

if known_import is None and self.current_source:
# Try scope of current module
try_qualname = f"{self.current_source.import_path}.{search_name}"
module_name = module_name_from_path(self.current_source)
try_qualname = f"{module_name}.{search_name}"
known_import = self.known_imports.get(try_qualname)
if known_import:
annotation_name = search_name
Expand Down
6 changes: 2 additions & 4 deletions src/docstub/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from ._analysis import (
KnownImport,
KnownImportCollector,
StaticInspector,
TypeCollector,
common_known_imports,
)
from ._config import Config
Expand Down Expand Up @@ -92,9 +92,7 @@ def main(source_dir, out_dir, config_path, verbose):
known_imports = common_known_imports()
for source_path in walk_source(source_dir):
logger.info("collecting types in %s", source_path)
known_imports_in_source = KnownImportCollector.collect(
source_path, module_name=source_path.import_path
)
known_imports_in_source = TypeCollector.collect(source_path)
known_imports.update(known_imports_in_source)
known_imports.update(KnownImport.many_from_config(config.known_imports))

Expand Down
36 changes: 22 additions & 14 deletions src/docstub/_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def __post_init__(self):
object.__setattr__(self, "imports", frozenset(self.imports))
if "~" in self.value:
raise ValueError(f"unexpected '~' in annotation value: {self.value}")
for import_ in self.imports:
if not isinstance(import_, KnownImport):
raise TypeError(f"unexpected type {type(import_)} in `imports`")

def __str__(self) -> str:
return self.value
Expand Down Expand Up @@ -95,6 +98,22 @@ def as_yields_generator(cls, yield_types, receive_types=()):
# TODO
raise NotImplementedError()

def as_optional(self):
"""Return optional version of this annotation by appending `| None`.

Returns
-------
optional : Annotation

Examples
--------
>>> Annotation(value="int").as_optional()
Annotation(value='int | None', imports=frozenset())
"""
value = f"{self.value} | None"
optional = type(self)(value=value, imports=self.imports)
return optional

@staticmethod
def _aggregate_annotations(*types):
"""Aggregate values and imports of given Annotations.
Expand All @@ -118,14 +137,7 @@ def _aggregate_annotations(*types):

GrammarErrorFallback = Annotation(
value="Any",
imports=frozenset(
(
KnownImport(
import_name="Any",
import_path="typing",
),
)
),
imports=frozenset((KnownImport(import_path="typing", import_name="Any"),)),
)


Expand Down Expand Up @@ -233,12 +245,8 @@ def types_or(self, tree):
return out

def optional(self, tree):
out = "None"
literal = [child for child in tree.children if child.type == "LITERAL"]
assert len(literal) <= 1
if literal:
out = lark.Discard # Type should cover the default
return out
logger.debug("dropping optional / default info")
return lark.Discard

def extra_info(self, tree):
logger.debug("dropping extra info")
Expand Down
Loading