diff --git a/nibabel/cifti2/cifti2.py b/nibabel/cifti2/cifti2.py index cb2e0cfaf..b2b67978b 100644 --- a/nibabel/cifti2/cifti2.py +++ b/nibabel/cifti2/cifti2.py @@ -1570,7 +1570,7 @@ def to_file_map(self, file_map=None, dtype=None): self.update_headers() header = self._nifti_header - extension = Cifti2Extension(content=self.header.to_xml()) + extension = Cifti2Extension.from_bytes(self.header.to_xml()) header.extensions = Nifti1Extensions( ext for ext in header.extensions if not isinstance(ext, Cifti2Extension) ) diff --git a/nibabel/cifti2/parse_cifti2.py b/nibabel/cifti2/parse_cifti2.py index 48c2e0653..764e3ae20 100644 --- a/nibabel/cifti2/parse_cifti2.py +++ b/nibabel/cifti2/parse_cifti2.py @@ -40,19 +40,15 @@ ) -class Cifti2Extension(Nifti1Extension): +class Cifti2Extension(Nifti1Extension[Cifti2Header]): code = 32 - def __init__(self, code=None, content=None): - Nifti1Extension.__init__(self, code=code or self.code, content=content) - - def _unmangle(self, value): + def _unmangle(self, value: bytes) -> Cifti2Header: parser = Cifti2Parser() parser.parse(string=value) - self._content = parser.header - return self._content + return parser.header - def _mangle(self, value): + def _mangle(self, value: Cifti2Header) -> bytes: if not isinstance(value, Cifti2Header): raise ValueError('Can only mangle a Cifti2Header.') return value.to_xml() diff --git a/nibabel/nifti1.py b/nibabel/nifti1.py index ecd94c10d..626d21752 100644 --- a/nibabel/nifti1.py +++ b/nibabel/nifti1.py @@ -13,12 +13,20 @@ from __future__ import annotations +import json +import sys +import typing as ty import warnings from io import BytesIO import numpy as np import numpy.linalg as npl +if sys.version_info < (3, 13): + from typing_extensions import Self, TypeVar # PY312 +else: + from typing import Self, TypeVar + from . import analyze # module import from .arrayproxy import get_obj_dtype from .batteryrunners import Report @@ -31,7 +39,19 @@ from .spm99analyze import SpmAnalyzeHeader from .volumeutils import Recoder, endian_codes, make_dt_codes -pdcm, have_dicom, _ = optional_package('pydicom') +if ty.TYPE_CHECKING: + import pydicom as pdcm + + have_dicom = True + DicomDataset = pdcm.Dataset +else: + pdcm, have_dicom, _ = optional_package('pydicom') + if have_dicom: + DicomDataset = pdcm.Dataset + else: + DicomDataset = ty.Any + +T = TypeVar('T', default=bytes) # nifti1 flat header definition for Analyze-like first 348 bytes # first number in comments indicates offset in file header in bytes @@ -283,15 +303,38 @@ ) -class Nifti1Extension: - """Baseclass for NIfTI1 header extensions. +class NiftiExtension(ty.Generic[T]): + """Base class for NIfTI header extensions. - This class is sufficient to handle very simple text-based extensions, such - as `comment`. More sophisticated extensions should/will be supported by - dedicated subclasses. + This class provides access to the extension content in various forms. + For simple extensions that expose data as bytes, text or JSON, this class + is sufficient. More complex extensions should be implemented as subclasses + that provide custom serialization/deserialization methods. + + Efficiency note: + + This class assumes that the runtime representation of the extension content + is mutable. Once a runtime representation is set, it is cached and will be + serialized on any attempt to access the extension content as bytes, including + determining the size of the extension in the NIfTI file. + + If the runtime representation is never accessed, the raw bytes will be used + without modification. While avoiding unnecessary deserialization, if there + are bytestrings that do not produce a valid runtime representation, they will + be written as-is, and may cause errors downstream. """ - def __init__(self, code, content): + code: int + encoding: ty.Optional[str] = None + _content: bytes + _object: ty.Optional[T] = None + + def __init__( + self, + code: ty.Union[int, str], + content: bytes = b'', + object: ty.Optional[T] = None, + ) -> None: """ Parameters ---------- @@ -299,94 +342,129 @@ def __init__(self, code, content): Canonical extension code as defined in the NIfTI standard, given either as integer or corresponding label (see :data:`~nibabel.nifti1.extension_codes`) - content : str - Extension content as read from the NIfTI file header. This content is - converted into a runtime representation. + content : bytes, optional + Extension content as read from the NIfTI file header. + object : optional + Extension content in runtime form. """ try: - self._code = extension_codes.code[code] + self.code = extension_codes.code[code] # type: ignore[assignment] except KeyError: - # XXX or fail or at least complain? - self._code = code - self._content = self._unmangle(content) - - def _unmangle(self, value): - """Convert the extension content into its runtime representation. + self.code = code # type: ignore[assignment] + self._content = content + if object is not None: + self._object = object - The default implementation does nothing at all. + @classmethod + def from_bytes(cls, content: bytes) -> Self: + """Create an extension from raw bytes. - Parameters - ---------- - value : str - Extension content as read from file. + This constructor may only be used in extension classes with a class + attribute `code` to indicate the extension type. + """ + if not hasattr(cls, 'code'): + raise NotImplementedError('from_bytes() requires a class attribute `code`') + return cls(cls.code, content=content) - Returns - ------- - The same object that was passed as `value`. + @classmethod + def from_object(cls, obj: T) -> Self: + """Create an extension from a runtime object. - Notes - ----- - Subclasses should reimplement this method to provide the desired - unmangling procedure and may return any type of object. + This constructor may only be used in extension classes with a class + attribute `code` to indicate the extension type. """ - return value - - def _mangle(self, value): - """Convert the extension content into NIfTI file header representation. + if not hasattr(cls, 'code'): + raise NotImplementedError('from_object() requires a class attribute `code`') + return cls(cls.code, object=obj) - The default implementation does nothing at all. + # Handle (de)serialization of extension content + # Subclasses may implement these methods to provide an alternative + # view of the extension content. If left unimplemented, the content + # must be bytes and is not modified. + def _mangle(self, obj: T) -> bytes: + raise NotImplementedError - Parameters - ---------- - value : str - Extension content in runtime form. + def _unmangle(self, content: bytes) -> T: + raise NotImplementedError - Returns - ------- - str + def _sync(self) -> None: + """Synchronize content with object. - Notes - ----- - Subclasses should reimplement this method to provide the desired - mangling procedure. + This permits the runtime representation to be modified in-place + and updates the bytes representation accordingly. """ - return value + if self._object is not None: + self._content = self._mangle(self._object) + + def __repr__(self) -> str: + try: + code = extension_codes.label[self.code] + except KeyError: + # deal with unknown codes + code = self.code + return f'{self.__class__.__name__}({code}, {self._content!r})' + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, self.__class__) + and self.code == other.code + and self.content == other.content + ) + + def __ne__(self, other): + return not self == other def get_code(self): """Return the canonical extension type code.""" - return self._code - - def get_content(self): - """Return the extension content in its runtime representation.""" + return self.code + + # Canonical access to extension content + # Follows the lead of httpx.Response .content, .text and .json() + # properties/methods + @property + def content(self) -> bytes: + """Return the extension content as raw bytes.""" + self._sync() return self._content - def get_sizeondisk(self): - """Return the size of the extension in the NIfTI file.""" - # need raw value size plus 8 bytes for esize and ecode - size = len(self._mangle(self._content)) - size += 8 - # extensions size has to be a multiple of 16 bytes - if size % 16 != 0: - size += 16 - (size % 16) - return size + @property + def text(self) -> str: + """Attempt to decode the extension content as text. - def __repr__(self): - try: - code = extension_codes.label[self._code] - except KeyError: - # deal with unknown codes - code = self._code + The encoding is determined by the `encoding` attribute, which may be + set by the user or subclass. If not set, the default encoding is 'utf-8'. + """ + return self.content.decode(self.encoding or 'utf-8') + + def json(self) -> ty.Any: + """Attempt to decode the extension content as JSON. - s = f"Nifti1Extension('{code}', '{self._content}')" - return s + If the content is not valid JSON, a JSONDecodeError or UnicodeDecodeError + will be raised. + """ + return json.loads(self.content) - def __eq__(self, other): - return (self._code, self._content) == (other._code, other._content) + def get_object(self) -> T: + """Return the extension content in its runtime representation. - def __ne__(self, other): - return not self == other + This method may return a different type for each extension type. + For simple use cases, consider using ``.content``, ``.text`` or ``.json()`` + instead. + """ + if self._object is None: + self._object = self._unmangle(self._content) + return self._object - def write_to(self, fileobj, byteswap): + # Backwards compatibility + get_content = get_object + + def get_sizeondisk(self) -> int: + """Return the size of the extension in the NIfTI file.""" + # need raw value size plus 8 bytes for esize and ecode, rounded up to next 16 bytes + # Rounding C+8 up to M is done by (C+8 + (M-1)) // M * M + return (len(self.content) + 23) // 16 * 16 + + def write_to(self, fileobj: ty.BinaryIO, byteswap: bool = False) -> None: """Write header extensions to fileobj Write starts at fileobj current file position. @@ -403,21 +481,75 @@ def write_to(self, fileobj, byteswap): None """ extstart = fileobj.tell() - rawsize = self.get_sizeondisk() + rawsize = self.get_sizeondisk() # Calls _sync() # write esize and ecode first - extinfo = np.array((rawsize, self._code), dtype=np.int32) + extinfo = np.array((rawsize, self.code), dtype=np.int32) if byteswap: extinfo = extinfo.byteswap() fileobj.write(extinfo.tobytes()) - # followed by the actual extension content - # XXX if mangling upon load is implemented, it should be reverted here - fileobj.write(self._mangle(self._content)) + # followed by the actual extension content, synced above + fileobj.write(self._content) # be nice and zero out remaining part of the extension till the # next 16 byte border - fileobj.write(b'\x00' * (extstart + rawsize - fileobj.tell())) + pad = extstart + rawsize - fileobj.tell() + if pad: + fileobj.write(bytes(pad)) + + +class Nifti1Extension(NiftiExtension[T]): + """Baseclass for NIfTI1 header extensions. + + This class is sufficient to handle very simple text-based extensions, such + as `comment`. More sophisticated extensions should/will be supported by + dedicated subclasses. + """ + + code = 0 # Default to unknown extension + + def _unmangle(self, value: bytes) -> T: + """Convert the extension content into its runtime representation. + + The default implementation does nothing at all. + + Parameters + ---------- + value : str + Extension content as read from file. + + Returns + ------- + The same object that was passed as `value`. + + Notes + ----- + Subclasses should reimplement this method to provide the desired + unmangling procedure and may return any type of object. + """ + return value # type: ignore[return-value] + + def _mangle(self, value: T) -> bytes: + """Convert the extension content into NIfTI file header representation. + + The default implementation does nothing at all. + + Parameters + ---------- + value : str + Extension content in runtime form. + + Returns + ------- + str + + Notes + ----- + Subclasses should reimplement this method to provide the desired + mangling procedure. + """ + return value # type: ignore[return-value] -class Nifti1DicomExtension(Nifti1Extension): +class Nifti1DicomExtension(Nifti1Extension[DicomDataset]): """NIfTI1 DICOM header extension This class is a thin wrapper around pydicom to read a binary DICOM @@ -427,7 +559,16 @@ class Nifti1DicomExtension(Nifti1Extension): header. """ - def __init__(self, code, content, parent_hdr=None): + code = 2 + _is_implicit_VR: bool = False + _is_little_endian: bool = True + + def __init__( + self, + code: ty.Union[int, str], + content: ty.Union[bytes, DicomDataset, None] = None, + parent_hdr: ty.Optional[Nifti1Header] = None, + ) -> None: """ Parameters ---------- @@ -452,30 +593,28 @@ def __init__(self, code, content, parent_hdr=None): code should always be 2 for DICOM. """ - self._code = code - if parent_hdr: + if code != 2: + raise ValueError(f'code must be 2 for DICOM. Got {code}.') + + if content is None: + content = pdcm.Dataset() + + if parent_hdr is not None: self._is_little_endian = parent_hdr.endianness == '<' - else: - self._is_little_endian = True + if isinstance(content, pdcm.dataset.Dataset): - self._is_implicit_VR = False - self._raw_content = self._mangle(content) - self._content = content + super().__init__(code, object=content) elif isinstance(content, bytes): # Got a byte string - unmangle it - self._raw_content = content - self._is_implicit_VR = self._guess_implicit_VR() - ds = self._unmangle(content, self._is_implicit_VR, self._is_little_endian) - self._content = ds - elif content is None: # initialize a new dicom dataset - self._is_implicit_VR = False - self._content = pdcm.dataset.Dataset() + self._is_implicit_VR = self._guess_implicit_VR(content) + super().__init__(code, content=content) else: raise TypeError( f'content must be either a bytestring or a pydicom Dataset. ' f'Got {content.__class__}' ) - def _guess_implicit_VR(self): + @staticmethod + def _guess_implicit_VR(content) -> bool: """Try to guess DICOM syntax by checking for valid VRs. Without a DICOM Transfer Syntax, it's difficult to tell if Value @@ -483,19 +622,17 @@ def _guess_implicit_VR(self): This reads where the first VR would be and checks it against a list of valid VRs """ - potential_vr = self._raw_content[4:6].decode() - if potential_vr in pdcm.values.converters.keys(): - implicit_VR = False - else: - implicit_VR = True - return implicit_VR - - def _unmangle(self, value, is_implicit_VR=False, is_little_endian=True): - bio = BytesIO(value) - ds = pdcm.filereader.read_dataset(bio, is_implicit_VR, is_little_endian) - return ds + potential_vr = content[4:6].decode() + return potential_vr not in pdcm.values.converters.keys() + + def _unmangle(self, obj: bytes) -> DicomDataset: + return pdcm.filereader.read_dataset( + BytesIO(obj), + self._is_implicit_VR, + self._is_little_endian, + ) - def _mangle(self, dataset): + def _mangle(self, dataset: DicomDataset) -> bytes: bio = BytesIO() dio = pdcm.filebase.DicomFileLike(bio) dio.is_implicit_VR = self._is_implicit_VR @@ -520,6 +657,21 @@ def _mangle(self, dataset): (12, 'workflow_fwds', Nifti1Extension), (14, 'freesurfer', Nifti1Extension), (16, 'pypickle', Nifti1Extension), + (18, 'mind_ident', NiftiExtension), + (20, 'b_value', NiftiExtension), + (22, 'spherical_direction', NiftiExtension), + (24, 'dt_component', NiftiExtension), + (26, 'shc_degreeorder', NiftiExtension), + (28, 'voxbo', NiftiExtension), + (30, 'caret', NiftiExtension), + ## Defined in nibabel.cifti2.parse_cifti2 + # (32, 'cifti', Cifti2Extension), + (34, 'variable_frame_timing', NiftiExtension), + (36, 'unassigned', NiftiExtension), + (38, 'eval', NiftiExtension), + (40, 'matlab', NiftiExtension), + (42, 'quantiphyse', NiftiExtension), + (44, 'mrs', NiftiExtension[ty.Dict[str, ty.Any]]), ), fields=('code', 'label', 'handler'), ) diff --git a/nibabel/tests/test_nifti1.py b/nibabel/tests/test_nifti1.py index 5ee4fb3c1..23e71c832 100644 --- a/nibabel/tests/test_nifti1.py +++ b/nibabel/tests/test_nifti1.py @@ -1224,6 +1224,32 @@ def test_ext_eq(): assert not ext == ext2 +def test_extension_content_access(): + ext = Nifti1Extension('comment', b'123') + # Unmangled content access + assert ext.get_content() == b'123' + + # Raw, text and JSON access + assert ext.content == b'123' + assert ext.text == '123' + assert ext.json() == 123 + + # Encoding can be set + ext.encoding = 'ascii' + assert ext.text == '123' + + # Test that encoding errors are caught + ascii_ext = Nifti1Extension('comment', 'hôpital'.encode('utf-8')) + ascii_ext.encoding = 'ascii' + with pytest.raises(UnicodeDecodeError): + ascii_ext.text + + json_ext = Nifti1Extension('unknown', b'{"a": 1}') + assert json_ext.content == b'{"a": 1}' + assert json_ext.text == '{"a": 1}' + assert json_ext.json() == {'a': 1} + + def test_extension_codes(): for k in extension_codes.keys(): Nifti1Extension(k, 'somevalue') @@ -1339,7 +1365,7 @@ def test_nifti_dicom_extension(): dcmbytes_explicit = struct.pack('') # Big Endian Nifti1Header dcmext = Nifti1DicomExtension(2, dcmbytes_explicit_be, parent_hdr=hdr_be) assert dcmext.__class__ == Nifti1DicomExtension - assert dcmext._guess_implicit_VR() is False + assert dcmext._is_implicit_VR is False assert dcmext.get_code() == 2 assert dcmext.get_content().PatientID == 'NiPy' assert dcmext.get_content()[0x10, 0x20].value == 'NiPy' diff --git a/pyproject.toml b/pyproject.toml index ff5168f9c..34d9f7bb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "numpy >=1.20", "packaging >=17", "importlib_resources >=5.12; python_version < '3.12'", + "typing_extensions >=4.6; python_version < '3.13'", ] classifiers = [ "Development Status :: 5 - Production/Stable", diff --git a/tox.ini b/tox.ini index 5df35c8d3..bd99d986c 100644 --- a/tox.ini +++ b/tox.ini @@ -77,7 +77,8 @@ extras = test deps = # General minimum dependencies: pin based on API usage min: packaging ==17 - min: importlib_resources ==1.3; python_version < '3.9' + min: importlib_resources ==5.12; python_version < '3.12' + min: typing_extensions ==4.6; python_version < '3.13' # NEP29/SPEC0 + 1yr: Test on minor release series within the last 3 years # We're extending this to all optional dependencies # This only affects the range that we test on; numpy is the only non-optional