Skip to content

Commit 15ab019

Browse files
committed
Add endian keyword argument to type reads
1 parent 194b1b5 commit 15ab019

22 files changed

+519
-242
lines changed

dissect/cstruct/bitbuffer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
class BitBuffer:
1010
"""Implements a bit buffer that can read and write bit fields."""
1111

12-
def __init__(self, stream: BinaryIO, endian: str):
12+
def __init__(self, stream: BinaryIO, *, endian: str, **kwargs):
1313
self.stream = stream
1414
self.endian = endian
15+
self.kwargs = kwargs
1516

1617
self._type: type[BaseType] | None = None
1718
self._buffer = 0
@@ -24,7 +25,7 @@ def read(self, field_type: type[BaseType], bits: int) -> int:
2425

2526
self._type = field_type
2627
self._remaining = field_type.size * 8
27-
self._buffer = field_type._read(self.stream)
28+
self._buffer = field_type._read(self.stream, endian=self.endian, **self.kwargs)
2829

2930
if isinstance(self._buffer, bytes):
3031
if self.endian == "<":
@@ -71,7 +72,7 @@ def write(self, field_type: type[BaseType], data: int, bits: int) -> None:
7172

7273
def flush(self) -> None:
7374
if self._type is not None:
74-
self._type._write(self.stream, self._buffer)
75+
self._type._write(self.stream, self._buffer, endian=self.endian, **self.kwargs)
7576
self._type = None
7677
self._remaining = 0
7778
self._buffer = 0

dissect/cstruct/compiler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def generate_source(self) -> str:
117117
"""
118118

119119
if any(field.bits for field in self.fields):
120-
preamble += "bit_reader = BitBuffer(stream, cls.cs.endian)\n"
120+
preamble += "bit_reader = BitBuffer(stream, endian=endian, **kwargs)\n"
121121

122122
read_code = "\n".join(self._generate_fields())
123123

@@ -130,7 +130,7 @@ def generate_source(self) -> str:
130130

131131
code = indent(dedent(preamble).lstrip() + read_code + dedent(outro), " ")
132132

133-
return f"def _read(cls, stream, context=None):\n{code}"
133+
return f"def _read(cls, stream, *, context=None, endian, **kwargs):\n{code}"
134134

135135
def _generate_fields(self) -> Iterator[str]:
136136
current_offset = 0
@@ -227,7 +227,7 @@ def align_to_field(field: Field) -> Iterator[str]:
227227
def _generate_structure(self, field: Field) -> Iterator[str]:
228228
template = f"""
229229
{"_s = stream.tell()" if field.type.dynamic else ""}
230-
r["{field._name}"] = {self._map_field(field)}._read(stream, context=r)
230+
r["{field._name}"] = {self._map_field(field)}._read(stream, context=r, endian=endian, **kwargs)
231231
{f's["{field._name}"] = stream.tell() - _s' if field.type.dynamic else ""}
232232
"""
233233

@@ -236,7 +236,7 @@ def _generate_structure(self, field: Field) -> Iterator[str]:
236236
def _generate_array(self, field: Field) -> Iterator[str]:
237237
template = f"""
238238
{"_s = stream.tell()" if field.type.dynamic else ""}
239-
r["{field._name}"] = {self._map_field(field)}._read(stream, context=r)
239+
r["{field._name}"] = {self._map_field(field)}._read(stream, context=r, endian=endian, **kwargs)
240240
{f's["{field._name}"] = stream.tell() - _s' if field.type.dynamic else ""}
241241
"""
242242

@@ -309,7 +309,7 @@ def _generate_packed(self, fields: list[Field]) -> Iterator[str]:
309309
item_parser = parser_template.format(type="_et", getter=f"_b[i:i + {field_type.type.size}]")
310310
list_comp = f"[{item_parser} for i in range(0, {count}, {field_type.type.size})]"
311311
elif issubclass(field_type.type, Pointer):
312-
item_parser = "_et.__new__(_et, e, stream, r)"
312+
item_parser = "_et.__new__(_et, e, stream, context=r, endian=endian, **kwargs)"
313313
list_comp = f"[{item_parser} for e in {getter}]"
314314
else:
315315
item_parser = parser_template.format(type="_et", getter="e")
@@ -320,7 +320,7 @@ def _generate_packed(self, fields: list[Field]) -> Iterator[str]:
320320
parser = f"type.__call__({self._map_field(field)}, {getter})"
321321
elif issubclass(field_type, Pointer):
322322
reads.append(f"_pt = {self._map_field(field)}")
323-
parser = f"_pt.__new__(_pt, {getter}, stream, r)"
323+
parser = f"_pt.__new__(_pt, {getter}, stream, context=r, endian=endian, **kwargs)"
324324
else:
325325
parser = parser_template.format(type=self._map_field(field), getter=getter)
326326

@@ -333,7 +333,7 @@ def _generate_packed(self, fields: list[Field]) -> Iterator[str]:
333333
if fmt == "x" or (len(fmt) == 2 and fmt[1] == "x"):
334334
unpack = ""
335335
else:
336-
unpack = f'data = _struct(cls.cs.endian, "{fmt}").unpack(buf)\n'
336+
unpack = f'data = _struct(endian, "{fmt}").unpack(buf)\n'
337337

338338
template = f"""
339339
buf = stream.read({size})

dissect/cstruct/cstruct.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from __future__ import annotations
22

33
import ctypes as _ctypes
4+
import inspect
45
import struct
56
import sys
67
import types
8+
import warnings
79
from pathlib import Path
8-
from typing import TYPE_CHECKING, Any, BinaryIO, TypeVar, cast
10+
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, TypeVar, cast
911

10-
from dissect.cstruct.exceptions import ResolveError
12+
from dissect.cstruct.exceptions import Error, ResolveError
1113
from dissect.cstruct.expression import Expression
1214
from dissect.cstruct.parser import CStyleParser, TokenParser
1315
from dissect.cstruct.types import (
@@ -27,6 +29,7 @@
2729
Void,
2830
Wchar,
2931
)
32+
from dissect.cstruct.types.base import normalize_endianness
3033

3134
if TYPE_CHECKING:
3235
from collections.abc import Iterable
@@ -35,20 +38,23 @@
3538

3639
T = TypeVar("T", bound=BaseType)
3740

41+
AllowedEndianness: TypeAlias = Literal["little", "big", "network", "<", ">", "!", "@", "="]
42+
Endianness: TypeAlias = Literal["<", ">", "!", "@", "="]
43+
3844

3945
class cstruct:
4046
"""Main class of cstruct. All types are registered in here.
4147
4248
Args:
43-
endian: The endianness to use when parsing.
49+
endian: The endianness to use when parsing (little, big, network, <, >, !, @ or =).
4450
pointer: The pointer type to use for pointers.
4551
"""
4652

4753
DEF_CSTYLE = 1
4854
DEF_LEGACY = 2
4955

50-
def __init__(self, load: str = "", *, endian: str = "<", pointer: str | None = None):
51-
self.endian = endian
56+
def __init__(self, load: str = "", *, endian: AllowedEndianness = "<", pointer: str | None = None):
57+
self.endian = normalize_endianness(endian)
5258

5359
self.consts = {}
5460
self.lookups = {}
@@ -242,6 +248,33 @@ def add_custom_type(
242248
alignment: The alignment of the type.
243249
**kwargs: Additional attributes to add to the type.
244250
"""
251+
# In cstruct 4.8 we changed the function signature of _read and _write
252+
# Check if the function signature is compatible, and throw an error if not
253+
for type_to_check in (type_, type_.ArrayType):
254+
type_name = type_.__name__ + (f".{type_.ArrayType.__name__}" if type_to_check is type_.ArrayType else "")
255+
256+
for method in ("_read", "_read_array", "_read_0", "_write", "_write_array", "_write_0"):
257+
if not hasattr(type_to_check, method):
258+
continue
259+
260+
signature = inspect.signature(getattr(type_to_check, method))
261+
262+
# We added a few keyword-only parameters to the function signature, but any custom type will
263+
# continue to work fine as long as its methods accept **kwargs
264+
if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
265+
raise Error(
266+
f"Custom type {type_name} has an incompatible {method} method signature. "
267+
"Please refer to the changelog of dissect.cstruct 4.8 for more information."
268+
)
269+
270+
# Only warn (rather than raise) if the method doesn't declare an explicit endian parameter
271+
if "endian" not in signature.parameters:
272+
warnings.warn(
273+
f"Custom type {type_name} is missing the 'endian' keyword-only parameter in its {method} method. " # noqa: E501
274+
"Please refer to the changelog of dissect.cstruct 4.8 for more information.",
275+
stacklevel=2,
276+
)
277+
245278
self.add_type(name, self._make_type(name, (type_,), size, alignment=alignment, attrs=kwargs))
246279

247280
def load(self, definition: str, deftype: int | None = None, **kwargs) -> cstruct:

0 commit comments

Comments
 (0)