Skip to content

Commit ab0863e

Browse files
authored
Fix typedefs for internal types and add literals to cstruct-stubgen (#118)
1 parent 2169211 commit ab0863e

File tree

3 files changed

+90
-20
lines changed

3 files changed

+90
-20
lines changed

dissect/cstruct/cstruct.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,10 +425,7 @@ def _make_union(
425425
) -> type[Structure]:
426426
return self._make_struct(name, fields, align=align, anonymous=anonymous, base=Union)
427427

428-
Z = TYPE_CHECKING
429-
430428
if TYPE_CHECKING:
431-
A = 1
432429
# ruff: noqa: PYI042
433430
_int = int
434431
_float = float

dissect/cstruct/tools/stubgen.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,29 @@ def generate_file_stub(path: Path, base: Path) -> str:
3939

4040
header = [
4141
"# Generated by cstruct-stubgen",
42-
"from typing import BinaryIO, overload",
42+
"from typing import BinaryIO, Literal, overload",
4343
"",
4444
"import dissect.cstruct as __cs__",
45+
"from typing_extensions import TypeAlias",
4546
]
4647
body = []
4748

4849
for name, obj in tmp_module.__dict__.items():
4950
if isinstance(obj, cstruct):
5051
stub = generate_cstruct_stub(obj, module_prefix="__cs__.", cls_name=f"_{name}")
5152
body.append(stub)
52-
body.append(f"{name}: _{name}")
53+
54+
if body[-1][-1] != "\n":
55+
body.append("")
56+
57+
body.append(f"# Technically `{name}` is an instance of `_{name}`, but then we can't use it in type hints")
58+
body.append(f"{name}: TypeAlias = _{name}")
59+
body.append("")
5360

5461
if not body:
5562
return ""
5663

57-
body_str = "\n".join(body)
58-
if "TypeAlias" in body_str:
59-
header.append("from typing_extensions import TypeAlias")
60-
61-
return "\n".join([*header, "", body_str, ""])
64+
return "\n".join([*header, "", "\n".join(body)])
6265

6366

6467
def generate_cstruct_stub(cs: cstruct, module_prefix: str = "", cls_name: str = "cstruct") -> str:
@@ -73,7 +76,7 @@ def generate_cstruct_stub(cs: cstruct, module_prefix: str = "", cls_name: str =
7376
for name, value in cs.consts.items():
7477
if name in empty_cs.consts:
7578
continue
76-
body.append(textwrap.indent(f"{name}: {type(value).__name__} = ...", prefix=indent))
79+
body.append(textwrap.indent(f"{name}: Literal[{value!r}] = ...", prefix=indent))
7780

7881
defined_names = set()
7982

@@ -82,10 +85,11 @@ def generate_cstruct_stub(cs: cstruct, module_prefix: str = "", cls_name: str =
8285
if name in empty_cs.typedefs:
8386
continue
8487

85-
if typedef.__name__ in defined_names:
88+
if typedef.__name__ in empty_cs.typedefs:
89+
stub = f"{name}: TypeAlias = {cs_prefix}{typedef.__name__}"
90+
elif typedef.__name__ in defined_names:
8691
# Create an alias to the type if we have already seen it before.
8792
stub = f"{name}: TypeAlias = {typedef.__name__}"
88-
8993
elif issubclass(typedef, (types.Enum, types.Flag)):
9094
stub = generate_enum_stub(typedef, cs_prefix=cs_prefix, module_prefix=module_prefix)
9195
elif issubclass(typedef, types.Structure):
@@ -129,7 +133,7 @@ def generate_generic_stub(
129133
cs_prefix: str = "",
130134
module_prefix: str = "",
131135
) -> str:
132-
return f"class {name_prefix}{type_.__name__}({module_prefix}{type_.__base__.__name__}): ..."
136+
return f"class {name_prefix}{type_.__name__}({module_prefix}{type_.__base__.__name__}): ...\n"
133137

134138

135139
def generate_enum_stub(
@@ -140,6 +144,7 @@ def generate_enum_stub(
140144
) -> str:
141145
result = [f"class {name_prefix}{enum.__name__}({module_prefix}{enum.__base__.__name__}):"]
142146
result.extend(f" {key} = ..." for key in enum.__members__)
147+
result.append("")
143148

144149
return "\n".join(result)
145150

@@ -181,6 +186,7 @@ def generate_structure_stub(
181186
result.append(
182187
textwrap.indent("def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...", prefix=indent)
183188
)
189+
result.append("")
184190
return "\n".join(result)
185191

186192

tests/test_tools_stubgen.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
class Test(Enum):
2828
A = ...
2929
B = ...
30+
3031
""",
3132
id="enum",
3233
),
@@ -41,6 +42,7 @@ class Test(Enum):
4142
class Test(Enum):
4243
A = ...
4344
B = ...
45+
4446
""",
4547
id="enum int8",
4648
),
@@ -55,6 +57,7 @@ class Test(Enum):
5557
class Test(Flag):
5658
A = ...
5759
B = ...
60+
5861
""",
5962
id="flag",
6063
),
@@ -63,7 +66,9 @@ class Test(Flag):
6366
def test_generate_enum_stub(cs: cstruct, cdef: str, expected: str) -> None:
6467
cs.load(cdef)
6568

66-
assert stubgen.generate_enum_stub(cs.Test) == textwrap.dedent(expected).strip()
69+
# We don't want to strip all trailing whitespace in case it's part of the intended expected output
70+
# So just remove one newline from the final """ block
71+
assert stubgen.generate_enum_stub(cs.Test) == textwrap.dedent(expected).lstrip()[:-1]
6772

6873

6974
@pytest.mark.parametrize(
@@ -86,6 +91,7 @@ class Test(Structure):
8691
def __init__(self, a: uint8 | None = ..., b: uint8 | None = ..., c: uint16 | None = ...): ...
8792
@overload
8893
def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
94+
8995
""",
9096
id="basic",
9197
),
@@ -106,6 +112,7 @@ class Test(Structure):
106112
def __init__(self, a: Array[uint8] | None = ..., b: CharArray | None = ..., c: WcharArray | None = ...): ...
107113
@overload
108114
def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
115+
109116
""", # noqa: E501
110117
id="array",
111118
),
@@ -124,6 +131,7 @@ class Test(Structure):
124131
def __init__(self, a: Pointer[uint8] | None = ..., b: Array[Pointer[uint8]] | None = ...): ...
125132
@overload
126133
def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
134+
127135
""",
128136
id="pointer",
129137
),
@@ -144,6 +152,7 @@ class Test(Structure):
144152
def __init__(self, a: uint8 | None = ..., b: uint8 | None = ...): ...
145153
@overload
146154
def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
155+
147156
""",
148157
id="anonymous nested",
149158
),
@@ -165,11 +174,13 @@ class __anonymous_0__(Union):
165174
def __init__(self, a: uint8 | None = ..., b: uint8 | None = ...): ...
166175
@overload
167176
def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
177+
168178
x: __anonymous_0__
169179
@overload
170180
def __init__(self, x: __anonymous_0__ | None = ...): ...
171181
@overload
172182
def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
183+
173184
""",
174185
id="named nested",
175186
),
@@ -191,11 +202,13 @@ class __anonymous_0__(Structure):
191202
def __init__(self, a: uint8 | None = ..., b: uint8 | None = ...): ...
192203
@overload
193204
def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
205+
194206
x: Array[__anonymous_0__]
195207
@overload
196208
def __init__(self, x: Array[__anonymous_0__] | None = ...): ...
197209
@overload
198210
def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
211+
199212
""",
200213
id="named nested array",
201214
),
@@ -204,7 +217,9 @@ def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
204217
def test_generate_structure_stub(cs: cstruct, cdef: str, expected: str) -> None:
205218
cs.load(cdef)
206219

207-
assert stubgen.generate_structure_stub(cs.Test) == textwrap.dedent(expected).strip()
220+
# We don't want to strip all trailing whitespace in case it's part of the intended expected output
221+
# So just remove one newline from the final """ block
222+
assert stubgen.generate_structure_stub(cs.Test) == textwrap.dedent(expected).lstrip()[:-1]
208223

209224

210225
@pytest.mark.parametrize(
@@ -225,10 +240,11 @@ def test_generate_structure_stub(cs: cstruct, cdef: str, expected: str) -> None:
225240
""",
226241
"""
227242
class cstruct(cstruct):
228-
TEST: int = ...
243+
TEST: Literal[1] = ...
229244
class TestEnum(Enum):
230245
A = ...
231246
B = ...
247+
232248
class TestStruct(Structure):
233249
a: cstruct.uint8
234250
@overload
@@ -253,16 +269,64 @@ class Test(Structure):
253269
def __init__(self, a: cstruct.uint8 | None = ...): ...
254270
@overload
255271
def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
272+
256273
_test: TypeAlias = Test
257274
""",
258275
id="alias stub",
259276
),
277+
pytest.param(
278+
"""
279+
typedef __u16 __fs16;
280+
typedef __u32 __fs32;
281+
typedef __u64 __fs64;
282+
283+
struct Test {
284+
__fs16 a;
285+
__fs32 b;
286+
__fs64 c;
287+
};
288+
""",
289+
"""
290+
class cstruct(cstruct):
291+
__fs16: TypeAlias = cstruct.uint16
292+
__fs32: TypeAlias = cstruct.uint32
293+
__fs64: TypeAlias = cstruct.uint64
294+
class Test(Structure):
295+
a: cstruct.uint16
296+
b: cstruct.uint32
297+
c: cstruct.uint64
298+
@overload
299+
def __init__(self, a: cstruct.uint16 | None = ..., b: cstruct.uint32 | None = ..., c: cstruct.uint64 | None = ...): ...
300+
@overload
301+
def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
302+
303+
""", # noqa: E501
304+
id="typedef stub",
305+
),
306+
pytest.param(
307+
"""
308+
#define INT 1
309+
#define FLOAT 2.0
310+
#define STRING "hello"
311+
#define BYTES b'c'
312+
""",
313+
"""
314+
class cstruct(cstruct):
315+
INT: Literal[1] = ...
316+
FLOAT: Literal[2.0] = ...
317+
STRING: Literal['hello'] = ...
318+
BYTES: Literal[b'c'] = ...
319+
""",
320+
id="define literals",
321+
),
260322
],
261323
)
262324
def test_generate_cstruct_stub(cs: cstruct, cdef: str, expected: str) -> None:
263325
cs.load(cdef)
264326

265-
assert stubgen.generate_cstruct_stub(cs) == textwrap.dedent(expected).strip()
327+
# We don't want to strip all trailing whitespace in case it's part of the intended expected output
328+
# So just remove one newline from the final """ block
329+
assert stubgen.generate_cstruct_stub(cs) == textwrap.dedent(expected).lstrip()[:-1]
266330

267331

268332
def test_generate_cstruct_stub_empty(cs: cstruct) -> None:
@@ -292,9 +356,10 @@ def test_generate_file_stub(tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cap
292356

293357
expected = """
294358
# Generated by cstruct-stubgen
295-
from typing import BinaryIO, overload
359+
from typing import BinaryIO, Literal, overload
296360
297361
import dissect.cstruct as __cs__
362+
from typing_extensions import TypeAlias
298363
299364
class _c_structure(__cs__.cstruct):
300365
class Test(__cs__.Structure):
@@ -304,7 +369,9 @@ class Test(__cs__.Structure):
304369
def __init__(self, a: _c_structure.uint32 | None = ..., b: _c_structure.uint32 | None = ...): ...
305370
@overload
306371
def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...
307-
c_structure: _c_structure
372+
373+
# Technically `c_structure` is an instance of `_c_structure`, but then we can't use it in type hints
374+
c_structure: TypeAlias = _c_structure
308375
"""
309376

310377
assert stubgen.generate_file_stub(test_file, tmp_path) == textwrap.dedent(expected).lstrip()

0 commit comments

Comments
 (0)