Skip to content

Commit 675481a

Browse files
committed
WIP support dataclasses
1 parent 6cad9c5 commit 675481a

3 files changed

Lines changed: 66 additions & 15 deletions

File tree

src/docstub/_stubs.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,25 +124,36 @@ class _Scope:
124124
""""""
125125

126126
type: ScopeType
127-
node: cst.CSTNode = None
127+
node: cst.CSTNode | None = None
128128

129129
@property
130-
def has_self_or_cls(self):
130+
def has_self_or_cls(self) -> bool:
131131
return self.type in {ScopeType.METHOD, ScopeType.CLASSMETHOD}
132132

133133
@property
134-
def is_method(self):
134+
def is_method(self) -> bool:
135135
return self.type in {
136136
ScopeType.METHOD,
137137
ScopeType.CLASSMETHOD,
138138
ScopeType.STATICMETHOD,
139139
}
140140

141141
@property
142-
def is_class_init(self):
142+
def is_class_init(self) -> bool:
143143
out = self.is_method and self.node.name.value == "__init__"
144144
return out
145145

146+
@property
147+
def is_dataclass(self) -> bool:
148+
if cstm.matches(self.node, cstm.ClassDef()):
149+
# Determine if dataclass
150+
decorators = cstm.findall(self.node, cstm.Decorator())
151+
is_dataclass = any(
152+
cstm.findall(d, cstm.Name("dataclass")) for d in decorators
153+
)
154+
return is_dataclass
155+
return False
156+
146157

147158
def _get_docstring_node(node):
148159
"""Extract the node with the docstring from a definition.
@@ -672,16 +683,27 @@ def leave_AnnAssign(self, original_node, updated_node):
672683
updated_node : cst.AnnAssign
673684
"""
674685
name = updated_node.target.value
675-
is_type_alias = cstm.matches(
676-
updated_node.annotation, cstm.Annotation(cstm.Name("TypeAlias"))
677-
)
678-
is__all__ = cstm.matches(updated_node.target, cstm.Name("__all__"))
679686

680-
# Remove value if not type alias or __all__
681-
if updated_node.value is not None and not is_type_alias and not is__all__:
682-
updated_node = updated_node.with_changes(
683-
value=None, equal=cst.MaybeSentinel.DEFAULT
687+
if updated_node.value is not None:
688+
is_type_alias = cstm.matches(
689+
updated_node.annotation, cstm.Annotation(cstm.Name("TypeAlias"))
684690
)
691+
is__all__ = cstm.matches(updated_node.target, cstm.Name("__all__"))
692+
is_dataclass = self._scope_stack[-1].is_dataclass
693+
is_classvar = any(
694+
cstm.findall(updated_node.annotation, cstm.Name("ClassVar"))
695+
)
696+
697+
# Replace with ellipses if dataclass
698+
if is_dataclass and not is_classvar:
699+
updated_node = updated_node.with_changes(
700+
value=cst.Ellipsis(), equal=cst.MaybeSentinel.DEFAULT
701+
)
702+
# Remove value if not type alias or __all__
703+
elif not is_type_alias and not is__all__:
704+
updated_node = updated_node.with_changes(
705+
value=None, equal=cst.MaybeSentinel.DEFAULT
706+
)
685707

686708
# Replace with type annotation from docstring, if available
687709
pytypes = self._pytypes_stack[-1]

stubtest_allow.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,3 @@ docstub\._version\..*
22
docstub\..*\.__match_args__$
33
docstub._cache.FuncSerializer.__type_params__
44
docstub._cli.main
5-
docstub._config.Config.__init__
6-
docstub._docstrings.Annotation.__init__
7-
docstub._stubs._Scope.__init__

tests/test_stubs.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,3 +394,35 @@ class Foo:
394394
# remove these empty lines from the result too
395395
result = dedent(result)
396396
assert expected == result
397+
398+
@pytest.mark.parametrize("decorator", ["dataclass", "dataclasses.dataclass"])
399+
def test_dataclass(self, decorator):
400+
source = dedent(
401+
f"""
402+
@{decorator}
403+
class Foo:
404+
a: float
405+
b: int = 3
406+
c: str = None
407+
d: dict[str, Any] = field(default_factory=dict)
408+
e: ClassVar
409+
f: ClassVar[float]
410+
g: Final[ClassVar[int]] = 1
411+
"""
412+
)
413+
expected = dedent(
414+
f"""
415+
@{decorator}
416+
class Foo:
417+
a: float
418+
b: int = ...
419+
c: str = ...
420+
d: dict[str, Any] = ...
421+
e: ClassVar
422+
f: ClassVar[float]
423+
g: Final[ClassVar[int]]
424+
"""
425+
)
426+
transformer = Py2StubTransformer()
427+
result = transformer.python_to_stub(source)
428+
assert expected == result

0 commit comments

Comments
 (0)