Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/example_pkg-stubs/_basic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def func_contains(
def func_literals(
a1: Literal[1, 3, "foo"], a2: Literal["uno", 2, "drei", "four"] = ...
) -> None: ...
def override_docstring_param(
d1: dict[str, float], d2: dict[Literal["a", "b", "c"], int]
) -> None: ...
def override_docstring_return() -> list[Literal[-1, 0, 1] | float]: ...
def func_use_from_elsewhere(
a1: CustomException,
a2: ExampleClass,
Expand All @@ -37,6 +41,9 @@ def func_use_from_elsewhere(
) -> tuple[CustomException, ExampleClass.NestedClass]: ...

class ExampleClass:

b1: int

class NestedClass:
def method_in_nested_class(self, a1: complex) -> None: ...

Expand Down
27 changes: 26 additions & 1 deletion examples/example_pkg/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# Existing imports are preserved
import logging
from typing import Literal

# Assign-statements are preserved
logger = logging.getLogger(__name__) # Inline comments are stripped
Expand Down Expand Up @@ -51,6 +52,25 @@ def func_literals(a1, a2="uno"):
"""


def override_docstring_param(d1, d2: dict[Literal["a", "b", "c"], int]):
"""Check type hint is kept and overrides docstring.

Parameters
----------
d1 : dict of {str : float}
d2 : dict of {str : int}
"""


def override_docstring_return() -> list[Literal[-1, 0, 1] | float]:
"""Check type hint is kept and overrides docstring.

Returns
-------
{"-inf", 0, 1, "inf"}
"""


def func_use_from_elsewhere(a1, a2, a3, a4):
"""Check if types with full import names are matched.

Expand All @@ -75,10 +95,15 @@ class ExampleClass:
----------
a1 : str
a2 : float, default 0

Attributes
----------
b1 : Sized
"""

class NestedClass:
b1: int

class NestedClass:
def method_in_nested_class(self, a1):
"""

Expand Down
75 changes: 54 additions & 21 deletions src/docstub/_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,25 +571,31 @@ def leave_FunctionDef(self, original_node, updated_node):
assert ds_annotations.returns.value
annotation_value = ds_annotations.returns.value

if original_node.returns is not None:
if original_node.returns is None:
annotation = cst.Annotation(cst.parse_expression(annotation_value))
node_changes["returns"] = annotation
# TODO: check imports
self._required_imports |= ds_annotations.returns.imports

else:
# Notify about ignored docstring annotation
# TODO: either remove message or print only in verbose mode
position = self.get_metadata(
cst.metadata.PositionProvider, original_node
).start
reporter = self.reporter.copy_with(
path=self.current_source, line=position.line
)
replaced = _inline_node_as_code(original_node.returns.annotation)
to_keep = _inline_node_as_code(original_node.returns.annotation)
details = (
f"{replaced}\n{reporter.underline(replaced)} -> {annotation_value}"
f"{reporter.underline(to_keep)} "
f"ignoring docstring: {annotation_value}"
)
reporter.message(
short="Replacing existing inline return annotation",
short="Keeping existing inline return annotation",
details=details,
)

annotation = cst.Annotation(cst.parse_expression(annotation_value))
node_changes["returns"] = annotation
self._required_imports |= ds_annotations.returns.imports
elif original_node.returns is None:
annotation = cst.Annotation(cst.parse_expression("None"))
node_changes["returns"] = annotation
Expand Down Expand Up @@ -633,10 +639,35 @@ def leave_Param(self, original_node, updated_node):
if pytype:
if defaults_to_none:
pytype = pytype.as_optional()
annotation = cst.Annotation(cst.parse_expression(pytype.value))
node_changes["annotation"] = annotation
if pytype.imports:
self._required_imports |= pytype.imports
annotation_value = pytype.value

if original_node.annotation is None:
annotation = cst.Annotation(cst.parse_expression(annotation_value))
node_changes["annotation"] = annotation
# TODO: check imports
if pytype.imports:
self._required_imports |= pytype.imports

else:
# Notify about ignored docstring annotation
# TODO: either remove message or print only in verbose mode
position = self.get_metadata(
cst.metadata.PositionProvider, original_node
).start
reporter = self.reporter.copy_with(
path=self.current_source, line=position.line
)
to_keep = cst.Module([]).code_for_node(
original_node.annotation.annotation
)
details = (
f"{reporter.underline(to_keep)} "
f"ignoring docstring: {annotation_value}"
)
reporter.message(
short="Keeping existing inline parameter annotation",
details=details,
)

# Potentially use "Incomplete" except for first param in (class)methods
elif not is_self_or_cls and updated_node.annotation is None:
Expand Down Expand Up @@ -764,31 +795,33 @@ def leave_AnnAssign(self, original_node, updated_node):
if pytypes and name in pytypes.attributes:
pytype = pytypes.attributes[name]
expr = cst.parse_expression(pytype.value)
self._required_imports |= pytype.imports

if updated_node.annotation is not None:
# Turn original annotation into str and print with context
if updated_node.annotation is None:
self._required_imports |= pytype.imports
updated_node = updated_node.with_deep_changes(
updated_node.annotation, annotation=expr
)

else:
# Notify about ignored docstring annotation
# TODO: either remove message or print only in verbose mode
position = self.get_metadata(
cst.metadata.PositionProvider, original_node
).start
reporter = self.reporter.copy_with(
path=self.current_source, line=position.line
)
replaced = cst.Module([]).code_for_node(
to_keep = cst.Module([]).code_for_node(
updated_node.annotation.annotation
)
details = (
f"{replaced}\n{reporter.underline(replaced)} -> {pytype.value}"
f"{reporter.underline(to_keep)} ignoring docstring: {pytype.value}"
)
reporter.message(
short="Replacing existing inline annotation",
short="Keeping existing inline annotation for assignment",
details=details,
)

updated_node = updated_node.with_deep_changes(
updated_node.annotation, annotation=expr
)

return updated_node

def visit_Module(self, node):
Expand Down
146 changes: 133 additions & 13 deletions tests/test_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,14 @@ def test_attributes_no_doctype(self, assign, expected, scope):
@pytest.mark.parametrize(
("assign", "doctype", "expected"),
[
# ("plain = 3", "plain : int", "plain: int"),
# ("plain = None", "plain : int", "plain: int"),
# ("x, y = (1, 2)", "x : int", "x: int; y: Incomplete"),
# Replace pre-existing annotations
("annotated: float = 1.0", "annotated : int", "annotated: int"),
("plain = 3", "plain : int", "plain: int"),
("plain = None", "plain : int", "plain: int"),
("x, y = (1, 2)", "x : int", "x: int; y: Incomplete"),
# Keep pre-existing annotations
("annotated: float = 1.0", "annotated : int", "annotated: float"),
# Type aliases are untouched
# ("alias: TypeAlias = int", "alias : str", "alias: TypeAlias = int"),
# ("type alias = int", "alias : str", "type alias = int"),
("alias: TypeAlias = int", "alias : str", "alias: TypeAlias = int"),
("type alias = int", "alias : str", "type alias = int"),
],
)
@pytest.mark.parametrize("scope", ["module", "class", "nested class"])
Expand Down Expand Up @@ -283,7 +283,7 @@ class Foo:
a: int
b: float

c: tuple
c: list
d: ClassVar[bool]

def __init__(self, a) -> None: ...
Expand All @@ -298,7 +298,127 @@ def test_undocumented_objects(self):
# https://typing.readthedocs.io/en/latest/guides/writing_stubs.html#undocumented-objects
pass

def test_existing_typed_return(self):
def test_keep_assign_param(self):
source = dedent(
"""
a: str
"""
)
expected = dedent(
"""
a: str
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

def test_keep_inline_assign_with_doctype(self, capsys):
source = dedent(
'''
"""
Attributes
----------
a : Sized
"""
a: str
'''
)
expected = dedent(
"""
a: str
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

captured = capsys.readouterr()
assert "Keeping existing inline annotation for assignment" in captured.out

def test_keep_class_assign_param(self):
source = dedent(
"""
class Foo:
a: str
"""
)
expected = dedent(
"""
class Foo:
a: str
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

def test_keep_inline_class_assign_with_doctype(self, capsys):
source = dedent(
'''
class Foo:
"""
Attributes
----------
a : Sized
"""
a: str
'''
)
expected = dedent(
"""
class Foo:
a: str
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

captured = capsys.readouterr()
assert "Keeping existing inline annotation for assignment" in captured.out

def test_keep_inline_param(self):
source = dedent(
"""
def foo(a: str) -> None:
pass
"""
)
expected = dedent(
"""
def foo(a: str) -> None: ...
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

def test_keep_inline_param_with_doctype(self, capsys):
source = dedent(
'''
def foo(a: int) -> None:
"""
Parameters
----------
a : Sized
"""
pass
'''
)
expected = dedent(
"""
def foo(a: int) -> None: ...
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

captured = capsys.readouterr()
assert "Keeping existing inline parameter annotation" in captured.out

def test_keep_inline_return(self):
source = dedent(
"""
def foo() -> str:
Expand All @@ -314,14 +434,14 @@ def foo() -> str: ...
result = transformer.python_to_stub(source)
assert expected == result

def test_overwriting_typed_return(self, capsys):
def test_keep_inline_return_with_doctype(self, capsys):
source = dedent(
'''
def foo() -> dict[str, int]:
def foo() -> int:
"""
Returns
-------
out : int
out : Sized
"""
pass
'''
Expand All @@ -336,7 +456,7 @@ def foo() -> int: ...
assert expected == result

captured = capsys.readouterr()
assert "Replacing existing inline return annotation" in captured.out
assert "Keeping existing inline return annotation" in captured.out

def test_preserved_type_comment(self):
source = dedent(
Expand Down