From fd43cbf31e5792fdafd08fd083b0d0dcab4eb886 Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Wed, 5 Feb 2025 13:40:38 -0600 Subject: [PATCH 1/9] Fix fill_value serialization of NaN --- src/zarr/core/metadata/v2.py | 53 ++++++++++--- src/zarr/testing/stateful.py | 2 +- src/zarr/testing/strategies.py | 136 ++++++++++++++++++++++++++++++++- tests/test_properties.py | 101 ++++++++++++++++++++++++ 4 files changed, 278 insertions(+), 14 deletions(-) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index 3d292c81b4..dd48dd39dc 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -19,6 +19,7 @@ from zarr.core.common import ChunkCoords import json +import numbers from dataclasses import dataclass, field, fields, replace import numcodecs @@ -149,18 +150,20 @@ def _json_convert( json_indent = config.get("json_indent") return { ZARRAY_JSON: prototype.buffer.from_bytes( - json.dumps(zarray_dict, default=_json_convert, indent=json_indent).encode() + json.dumps( + zarray_dict, default=_json_convert, indent=json_indent, allow_nan=False + ).encode() ), ZATTRS_JSON: prototype.buffer.from_bytes( - json.dumps(zattrs_dict, indent=json_indent).encode() + json.dumps(zattrs_dict, indent=json_indent, allow_nan=False).encode() ), } @classmethod def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: - # make a copy to protect the original from modification + # Make a copy to protect the original from modification. _data = data.copy() - # check that the zarr_format attribute is correct + # Check that the zarr_format attribute is correct. _ = parse_zarr_format(_data.pop("zarr_format")) dtype = parse_dtype(_data["dtype"]) @@ -169,20 +172,46 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: if fill_value_encoded is not None: fill_value = base64.standard_b64decode(fill_value_encoded) _data["fill_value"] = fill_value - - # zarr v2 allowed arbitrary keys here. 
- # We don't want the ArrayV2Metadata constructor to fail just because someone put an - # extra key in the metadata. + else: + fill_value = _data.get("fill_value") + if fill_value is not None: + if np.issubdtype(dtype, np.datetime64): + if fill_value == "NaT": + _data["fill_value"] = np.array("NaT", dtype=dtype)[()] + else: + _data["fill_value"] = np.array(fill_value, dtype=dtype)[()] + elif dtype.kind == "c" and isinstance(fill_value, list): + if len(fill_value) == 2: + val = complex(float(fill_value[0]), float(fill_value[1])) + _data["fill_value"] = np.array(val, dtype=dtype)[()] + elif dtype.kind in "f" and isinstance(fill_value, str): + if fill_value in {"NaN", "Infinity", "-Infinity"}: + _data["fill_value"] = np.array(fill_value, dtype=dtype)[()] + # zarr v2 allowed arbitrary keys in the metadata. + # Filter the keys to only those expected by the constructor. expected = {x.name for x in fields(cls)} - # https://github.com/zarr-developers/zarr-python/issues/2269 - # handle the renames expected |= {"dtype", "chunks"} - _data = {k: v for k, v in _data.items() if k in expected} return cls(**_data) def to_dict(self) -> dict[str, JSON]: + def _sanitize_fill_value(fv: Any): + if fv is None: + return fv + elif isinstance(fv, np.datetime64): + if np.isnat(fv): + return "NaT" + return np.datetime_as_string(fv) + elif isinstance(fv, numbers.Real): + if np.isnan(fv): + fv = "NaN" + elif np.isinf(fv): + fv = "Infinity" if fv > 0 else "-Infinity" + elif isinstance(fv, numbers.Complex): + fv = [_sanitize_fill_value(fv.real), _sanitize_fill_value(fv.imag)] + return fv + zarray_dict = super().to_dict() if self.dtype.kind in "SV" and self.fill_value is not None: @@ -192,6 +221,7 @@ def to_dict(self) -> dict[str, JSON]: fill_value = base64.standard_b64encode(cast(bytes, self.fill_value)).decode("ascii") zarray_dict["fill_value"] = fill_value + zarray_dict["fill_value"] = _sanitize_fill_value(zarray_dict["fill_value"]) _ = zarray_dict.pop("dtype") dtype_json: JSON # In the case 
of zarr v2, the simplest i.e., '|VXX' dtype is represented as a string @@ -300,7 +330,6 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: ------- An instance of `dtype`, or `None`, or any python object (in the case of an object dtype) """ - if fill_value is None or dtype.hasobject: # no fill value pass diff --git a/src/zarr/testing/stateful.py b/src/zarr/testing/stateful.py index 3e8dbcdf04..105a8cdb42 100644 --- a/src/zarr/testing/stateful.py +++ b/src/zarr/testing/stateful.py @@ -85,7 +85,7 @@ def add_group(self, name: str, data: DataObject) -> None: @rule( data=st.data(), name=node_names, - array_and_chunks=np_array_and_chunks(arrays=numpy_arrays(zarr_formats=st.just(3))), + array_and_chunks=np_array_and_chunks(nparrays=numpy_arrays(zarr_formats=st.just(3))), ) def add_array( self, diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 0e25e44592..dbc2015e38 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -3,8 +3,9 @@ import hypothesis.extra.numpy as npst import hypothesis.strategies as st +import numcodecs import numpy as np -from hypothesis import given, settings # noqa: F401 +from hypothesis import assume, given, settings # noqa: F401 from hypothesis.strategies import SearchStrategy import zarr @@ -344,3 +345,136 @@ def make_request(start: int, length: int) -> RangeByteRequest: ) key_tuple = st.tuples(keys, byte_ranges) return st.lists(key_tuple, min_size=1, max_size=10) + + +def simple_text(): + """A strategy for generating simple text strings.""" + return st.text(st.characters(min_codepoint=32, max_codepoint=126), min_size=1, max_size=10) + + +def simple_attrs(): + """A strategy for generating simple attribute dictionaries.""" + return st.dictionaries( + simple_text(), + st.one_of( + st.integers(), + st.floats(allow_nan=False, allow_infinity=False), + st.booleans(), + simple_text(), + ), + ) + + +def array_shapes(min_dims=1, max_dims=3, max_len=100): + """A strategy for 
generating array shapes.""" + return st.lists( + st.integers(min_value=1, max_value=max_len), min_size=min_dims, max_size=max_dims + ) + + +# def zarr_compressors(): +# """A strategy for generating Zarr compressors.""" +# return st.sampled_from([None, Blosc(), GZip(), Zstd(), LZ4()]) + + +# def zarr_codecs(): +# """A strategy for generating Zarr codecs.""" +# return st.sampled_from([BytesCodec(), Blosc(), GZip(), Zstd(), LZ4()]) + + +def zarr_filters(): + """A strategy for generating Zarr filters.""" + return st.lists( + st.just(numcodecs.Delta(dtype="i4")), min_size=0, max_size=2 + ) # Example filter, expand as needed + + +def zarr_storage_transformers(): + """A strategy for generating Zarr storage transformers.""" + return st.lists( + st.dictionaries( + simple_text(), st.one_of(st.integers(), st.floats(), st.booleans(), simple_text()) + ), + min_size=0, + max_size=2, + ) + + +@st.composite +def array_metadata_v2(draw: st.DrawFn) -> ArrayV2Metadata: + """Generates valid ArrayV2Metadata objects for property-based testing.""" + dims = draw(st.integers(min_value=1, max_value=3)) # Limit dimensions for complexity + shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100))) + max_chunk_len = max(shape) if shape else 100 + chunks = tuple( + draw( + st.lists( + st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims + ) + ) + ) + + # Validate shape and chunks relationship + assume(all(c <= s for s, c in zip(shape, chunks, strict=False))) # Chunk size must be <= shape + + dtype = draw(v2_dtypes()) + fill_value = draw(st.one_of([st.none(), npst.from_dtype(dtype)])) + order = draw(st.sampled_from(["C", "F"])) + dimension_separator = draw(st.sampled_from([".", "/"])) + # compressor = draw(zarr_compressors()) + filters = tuple(draw(zarr_filters())) if draw(st.booleans()) else None + attributes = draw(simple_attrs()) + + # Construct the metadata object. Type hints are crucial here for correctness. 
+ return ArrayV2Metadata( + shape=shape, + dtype=dtype, + chunks=chunks, + fill_value=fill_value, + order=order, + dimension_separator=dimension_separator, + # compressor=compressor, + filters=filters, + attributes=attributes, + ) + + +@st.composite +def array_metadata_v3(draw: st.DrawFn) -> ArrayV3Metadata: + """Generates valid ArrayV3Metadata objects for property-based testing.""" + dims = draw(st.integers(min_value=1, max_value=3)) + shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100))) + max_chunk_len = max(shape) if shape else 100 + chunks = tuple( + draw( + st.lists( + st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims + ) + ) + ) + assume(all(c <= s for s, c in zip(shape, chunks, strict=False))) + + dtype = draw(v3_dtypes()) + fill_value = draw(npst.from_dtype(dtype)) + chunk_grid = RegularChunkGrid(chunks) # Ensure chunks is passed as tuple. + chunk_key_encoding = DefaultChunkKeyEncoding(separator="/") # Or st.sampled_from(["/", "."]) + # codecs = tuple(draw(st.lists(zarr_codecs(), min_size=0, max_size=3))) + attributes = draw(simple_attrs()) + dimension_names = ( + tuple(draw(st.lists(st.one_of(st.none(), simple_text()), min_size=dims, max_size=dims))) + if draw(st.booleans()) + else None + ) + storage_transformers = tuple(draw(zarr_storage_transformers())) + + return ArrayV3Metadata( + shape=shape, + data_type=dtype, + chunk_grid=chunk_grid, + chunk_key_encoding=chunk_key_encoding, + fill_value=fill_value, + # codecs=codecs, + attributes=attributes, + dimension_names=dimension_names, + storage_transformers=storage_transformers, + ) diff --git a/tests/test_properties.py b/tests/test_properties.py index acecd44810..07b9bb36fa 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,3 +1,7 @@ +import dataclasses +import json + +import numpy as np import pytest from numpy.testing import assert_array_equal @@ -10,10 +14,12 @@ from hypothesis import assume, given from zarr.abc.store import 
Store +from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.sync import sync from zarr.testing.strategies import ( array_metadata, + array_metadata_v2, arrays, basic_indices, numpy_arrays, @@ -23,6 +29,60 @@ ) +def deep_equal(a, b): + """Deep equality check w/ NaN e to handle array metadata serialization and deserialization behaviors""" + if isinstance(a, (complex, np.complexfloating)) and isinstance( + b, (complex, np.complexfloating) + ): + # Convert to Python float to force standard NaN handling. + a_real, a_imag = float(a.real), float(a.imag) + b_real, b_imag = float(b.real), float(b.imag) + # If both parts are NaN, consider them equal. + if np.isnan(a_real) and np.isnan(b_real): + real_eq = True + else: + real_eq = a_real == b_real + if np.isnan(a_imag) and np.isnan(b_imag): + imag_eq = True + else: + imag_eq = a_imag == b_imag + return real_eq and imag_eq + + # Handle floats (including numpy floating types) and treat NaNs as equal. + if isinstance(a, (float, np.floating)) and isinstance(b, (float, np.floating)): + if np.isnan(a) and np.isnan(b): + return True + return a == b + + # Handle numpy.datetime64 values, treating NaT as equal. + if isinstance(a, np.datetime64) and isinstance(b, np.datetime64): + if np.isnat(a) and np.isnat(b): + return True + return a == b + + # Handle numpy arrays. + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + if a.shape != b.shape: + return False + # Compare elementwise. + return all(deep_equal(x, y) for x, y in zip(a.flat, b.flat, strict=False)) + + # Handle dictionaries. + if isinstance(a, dict) and isinstance(b, dict): + if set(a.keys()) != set(b.keys()): + return False + return all(deep_equal(a[k], b[k]) for k in a) + + # Handle lists and tuples. 
+ if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)): + if len(a) != len(b): + return False + return all(deep_equal(x, y) for x, y in zip(a, b, strict=False)) + + # Fallback to default equality. + return a == b + + @given(data=st.data(), zarr_format=zarr_formats) def test_roundtrip(data: st.DataObject, zarr_format: int) -> None: nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format))) @@ -135,3 +195,44 @@ async def test_roundtrip_array_metadata( # nparray = data.draw(np_arrays) # zarray = data.draw(arrays(arrays=st.just(nparray))) # assert_array_equal(nparray, zarray[:]) + + +@given(array_metadata_v2()) +def test_v2meta_roundtrip(metadata): + buffer_dict = metadata.to_buffer_dict(prototype=default_buffer_prototype()) + zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode()) + zattrs_dict = json.loads(buffer_dict[ZATTRS_JSON].to_bytes().decode()) + + # zattrs and zarray are separate in v2, we have to add attributes back prior to `from_dict` + zarray_dict["attributes"] = zattrs_dict + + metadata_roundtripped = ArrayV2Metadata.from_dict(zarray_dict) + + # Convert both metadata instances to dictionaries. 
+ orig = dataclasses.asdict(metadata) + rt = dataclasses.asdict(metadata_roundtripped) + + assert deep_equal(orig, rt), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}" + + +@given(npst.from_dtype(dtype=np.dtype("float64"), allow_nan=True, allow_infinity=True)) +def test_v2meta_nan_and_infinity(fill_value): + metadata = ArrayV2Metadata( + shape=[10], + dtype=np.dtype("float64"), + chunks=[5], + fill_value=fill_value, + order="C", + ) + + buffer_dict = metadata.to_buffer_dict(prototype=default_buffer_prototype()) + zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode()) + + if np.isnan(fill_value): + assert zarray_dict["fill_value"] == "NaN" + elif np.isinf(fill_value) and fill_value > 0: + assert zarray_dict["fill_value"] == "Infinity" + elif np.isinf(fill_value): + assert zarray_dict["fill_value"] == "-Infinity" + else: + assert zarray_dict["fill_value"] == fill_value From 6301b15d393517973366ca9c8836001883079186 Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Tue, 18 Feb 2025 13:27:02 -0600 Subject: [PATCH 2/9] Round trip serialization for array metadata v2/v3 --- changes/2802.fix.rst | 1 + src/zarr/core/metadata/v2.py | 25 +++-- src/zarr/testing/stateful.py | 2 +- src/zarr/testing/strategies.py | 134 --------------------------- tests/test_properties.py | 162 ++++++++++++++++++++++----------- 5 files changed, 124 insertions(+), 200 deletions(-) create mode 100644 changes/2802.fix.rst diff --git a/changes/2802.fix.rst b/changes/2802.fix.rst new file mode 100644 index 0000000000..471ddf66f4 --- /dev/null +++ b/changes/2802.fix.rst @@ -0,0 +1 @@ +Fix `fill_value` serialization for `NaN` in `ArrayV2Metadata` and add property-based testing of round-trip serialization \ No newline at end of file diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index dd48dd39dc..0c32f1b062 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -170,7 +170,7 @@ def from_dict(cls, data: dict[str, Any]) -> 
ArrayV2Metadata: if dtype.kind in "SV": fill_value_encoded = _data.get("fill_value") if fill_value_encoded is not None: - fill_value = base64.standard_b64decode(fill_value_encoded) + fill_value: Any = base64.standard_b64decode(fill_value_encoded) _data["fill_value"] = fill_value else: fill_value = _data.get("fill_value") @@ -180,13 +180,11 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: _data["fill_value"] = np.array("NaT", dtype=dtype)[()] else: _data["fill_value"] = np.array(fill_value, dtype=dtype)[()] - elif dtype.kind == "c" and isinstance(fill_value, list): - if len(fill_value) == 2: - val = complex(float(fill_value[0]), float(fill_value[1])) - _data["fill_value"] = np.array(val, dtype=dtype)[()] - elif dtype.kind in "f" and isinstance(fill_value, str): - if fill_value in {"NaN", "Infinity", "-Infinity"}: - _data["fill_value"] = np.array(fill_value, dtype=dtype)[()] + elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2: + val = complex(float(fill_value[0]), float(fill_value[1])) + _data["fill_value"] = np.array(val, dtype=dtype)[()] + elif dtype.kind in "f" and fill_value in {"NaN", "Infinity", "-Infinity"}: + _data["fill_value"] = np.array(fill_value, dtype=dtype)[()] # zarr v2 allowed arbitrary keys in the metadata. # Filter the keys to only those expected by the constructor. 
expected = {x.name for x in fields(cls)} @@ -196,7 +194,7 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: return cls(**_data) def to_dict(self) -> dict[str, JSON]: - def _sanitize_fill_value(fv: Any): + def _sanitize_fill_value(fv: Any) -> JSON: if fv is None: return fv elif isinstance(fv, np.datetime64): @@ -204,13 +202,14 @@ def _sanitize_fill_value(fv: Any): return "NaT" return np.datetime_as_string(fv) elif isinstance(fv, numbers.Real): - if np.isnan(fv): + float_fv = float(fv) + if np.isnan(float_fv): fv = "NaN" - elif np.isinf(fv): - fv = "Infinity" if fv > 0 else "-Infinity" + elif np.isinf(float_fv): + fv = "Infinity" if float_fv > 0 else "-Infinity" elif isinstance(fv, numbers.Complex): fv = [_sanitize_fill_value(fv.real), _sanitize_fill_value(fv.imag)] - return fv + return cast(JSON, fv) zarray_dict = super().to_dict() diff --git a/src/zarr/testing/stateful.py b/src/zarr/testing/stateful.py index 105a8cdb42..3e8dbcdf04 100644 --- a/src/zarr/testing/stateful.py +++ b/src/zarr/testing/stateful.py @@ -85,7 +85,7 @@ def add_group(self, name: str, data: DataObject) -> None: @rule( data=st.data(), name=node_names, - array_and_chunks=np_array_and_chunks(nparrays=numpy_arrays(zarr_formats=st.just(3))), + array_and_chunks=np_array_and_chunks(arrays=numpy_arrays(zarr_formats=st.just(3))), ) def add_array( self, diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index dbc2015e38..63dbfc2f1e 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -3,7 +3,6 @@ import hypothesis.extra.numpy as npst import hypothesis.strategies as st -import numcodecs import numpy as np from hypothesis import assume, given, settings # noqa: F401 from hypothesis.strategies import SearchStrategy @@ -345,136 +344,3 @@ def make_request(start: int, length: int) -> RangeByteRequest: ) key_tuple = st.tuples(keys, byte_ranges) return st.lists(key_tuple, min_size=1, max_size=10) - - -def simple_text(): - """A strategy for 
generating simple text strings.""" - return st.text(st.characters(min_codepoint=32, max_codepoint=126), min_size=1, max_size=10) - - -def simple_attrs(): - """A strategy for generating simple attribute dictionaries.""" - return st.dictionaries( - simple_text(), - st.one_of( - st.integers(), - st.floats(allow_nan=False, allow_infinity=False), - st.booleans(), - simple_text(), - ), - ) - - -def array_shapes(min_dims=1, max_dims=3, max_len=100): - """A strategy for generating array shapes.""" - return st.lists( - st.integers(min_value=1, max_value=max_len), min_size=min_dims, max_size=max_dims - ) - - -# def zarr_compressors(): -# """A strategy for generating Zarr compressors.""" -# return st.sampled_from([None, Blosc(), GZip(), Zstd(), LZ4()]) - - -# def zarr_codecs(): -# """A strategy for generating Zarr codecs.""" -# return st.sampled_from([BytesCodec(), Blosc(), GZip(), Zstd(), LZ4()]) - - -def zarr_filters(): - """A strategy for generating Zarr filters.""" - return st.lists( - st.just(numcodecs.Delta(dtype="i4")), min_size=0, max_size=2 - ) # Example filter, expand as needed - - -def zarr_storage_transformers(): - """A strategy for generating Zarr storage transformers.""" - return st.lists( - st.dictionaries( - simple_text(), st.one_of(st.integers(), st.floats(), st.booleans(), simple_text()) - ), - min_size=0, - max_size=2, - ) - - -@st.composite -def array_metadata_v2(draw: st.DrawFn) -> ArrayV2Metadata: - """Generates valid ArrayV2Metadata objects for property-based testing.""" - dims = draw(st.integers(min_value=1, max_value=3)) # Limit dimensions for complexity - shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100))) - max_chunk_len = max(shape) if shape else 100 - chunks = tuple( - draw( - st.lists( - st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims - ) - ) - ) - - # Validate shape and chunks relationship - assume(all(c <= s for s, c in zip(shape, chunks, strict=False))) # Chunk size must be <= shape - - 
dtype = draw(v2_dtypes()) - fill_value = draw(st.one_of([st.none(), npst.from_dtype(dtype)])) - order = draw(st.sampled_from(["C", "F"])) - dimension_separator = draw(st.sampled_from([".", "/"])) - # compressor = draw(zarr_compressors()) - filters = tuple(draw(zarr_filters())) if draw(st.booleans()) else None - attributes = draw(simple_attrs()) - - # Construct the metadata object. Type hints are crucial here for correctness. - return ArrayV2Metadata( - shape=shape, - dtype=dtype, - chunks=chunks, - fill_value=fill_value, - order=order, - dimension_separator=dimension_separator, - # compressor=compressor, - filters=filters, - attributes=attributes, - ) - - -@st.composite -def array_metadata_v3(draw: st.DrawFn) -> ArrayV3Metadata: - """Generates valid ArrayV3Metadata objects for property-based testing.""" - dims = draw(st.integers(min_value=1, max_value=3)) - shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100))) - max_chunk_len = max(shape) if shape else 100 - chunks = tuple( - draw( - st.lists( - st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims - ) - ) - ) - assume(all(c <= s for s, c in zip(shape, chunks, strict=False))) - - dtype = draw(v3_dtypes()) - fill_value = draw(npst.from_dtype(dtype)) - chunk_grid = RegularChunkGrid(chunks) # Ensure chunks is passed as tuple. 
- chunk_key_encoding = DefaultChunkKeyEncoding(separator="/") # Or st.sampled_from(["/", "."]) - # codecs = tuple(draw(st.lists(zarr_codecs(), min_size=0, max_size=3))) - attributes = draw(simple_attrs()) - dimension_names = ( - tuple(draw(st.lists(st.one_of(st.none(), simple_text()), min_size=dims, max_size=dims))) - if draw(st.booleans()) - else None - ) - storage_transformers = tuple(draw(zarr_storage_transformers())) - - return ArrayV3Metadata( - shape=shape, - data_type=dtype, - chunk_grid=chunk_grid, - chunk_key_encoding=chunk_key_encoding, - fill_value=fill_value, - # codecs=codecs, - attributes=attributes, - dimension_names=dimension_names, - storage_transformers=storage_transformers, - ) diff --git a/tests/test_properties.py b/tests/test_properties.py index 07b9bb36fa..0cc995c9fc 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,5 +1,6 @@ import dataclasses import json +import numbers import numpy as np import pytest @@ -14,12 +15,11 @@ from hypothesis import assume, given from zarr.abc.store import Store -from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON +from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.sync import sync from zarr.testing.strategies import ( array_metadata, - array_metadata_v2, arrays, basic_indices, numpy_arrays, @@ -30,14 +30,12 @@ def deep_equal(a, b): - """Deep equality check w/ NaN e to handle array metadata serialization and deserialization behaviors""" + """Deep equality check with handling of special cases for array metadata classes""" if isinstance(a, (complex, np.complexfloating)) and isinstance( b, (complex, np.complexfloating) ): - # Convert to Python float to force standard NaN handling. a_real, a_imag = float(a.real), float(a.imag) b_real, b_imag = float(b.real), float(b.imag) - # If both parts are NaN, consider them equal. 
if np.isnan(a_real) and np.isnan(b_real): real_eq = True else: @@ -48,43 +46,36 @@ def deep_equal(a, b): imag_eq = a_imag == b_imag return real_eq and imag_eq - # Handle floats (including numpy floating types) and treat NaNs as equal. if isinstance(a, (float, np.floating)) and isinstance(b, (float, np.floating)): if np.isnan(a) and np.isnan(b): return True return a == b - # Handle numpy.datetime64 values, treating NaT as equal. if isinstance(a, np.datetime64) and isinstance(b, np.datetime64): if np.isnat(a) and np.isnat(b): return True return a == b - # Handle numpy arrays. if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): if a.shape != b.shape: return False - # Compare elementwise. return all(deep_equal(x, y) for x, y in zip(a.flat, b.flat, strict=False)) - # Handle dictionaries. if isinstance(a, dict) and isinstance(b, dict): if set(a.keys()) != set(b.keys()): return False return all(deep_equal(a[k], b[k]) for k in a) - # Handle lists and tuples. if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)): if len(a) != len(b): return False return all(deep_equal(x, y) for x, y in zip(a, b, strict=False)) - # Fallback to default equality. return a == b @given(data=st.data(), zarr_format=zarr_formats) -def test_roundtrip(data: st.DataObject, zarr_format: int) -> None: +def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None: nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format))) zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format))) assert_array_equal(nparray, zarray[:]) @@ -163,9 +154,17 @@ def test_vindex(data: st.DataObject) -> None: @given(store=stores, meta=array_metadata()) # type: ignore[misc] -async def test_roundtrip_array_metadata( +async def test_roundtrip_array_metadata_from_store( store: Store, meta: ArrayV2Metadata | ArrayV3Metadata ) -> None: + """ + Verify that the I/O for metadata in a store are lossless. 
+ + This test serializes an ArrayV2Metadata or ArrayV3Metadata object to a dict + of buffers via `to_buffer_dict`, writes each buffer to a store under keys + prefixed with "0/", and then reads them back. The test asserts that each + retrieved buffer exactly matches the original buffer. + """ asdict = meta.to_buffer_dict(prototype=default_buffer_prototype()) for key, expected in asdict.items(): await store.set(f"0/{key}", expected) @@ -173,6 +172,41 @@ async def test_roundtrip_array_metadata( assert actual == expected +@given(data=st.data(), zarr_format=zarr_formats) +def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: int) -> None: + """ + Verify that JSON serialization and deserialization of metadata is lossless. + + For Zarr v2: + - The metadata is split into two JSON documents (one for array data and one + for attributes). The test merges the attributes back before deserialization. + For Zarr v3: + - All metadata is stored in a single JSON document. No manual merger is necessary. + + The test then converts both the original and round-tripped metadata objects + into dictionaries using `dataclasses.asdict` and uses a deep equality check + to verify that the roundtrip has preserved all fields (including special + cases like NaN, Infinity, complex numbers, and datetime values). 
+ """ + metadata = data.draw(array_metadata(zarr_formats=st.just(zarr_format))) + buffer_dict = metadata.to_buffer_dict(prototype=default_buffer_prototype()) + + if zarr_format == 2: + zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode()) + zattrs_dict = json.loads(buffer_dict[ZATTRS_JSON].to_bytes().decode()) + # zattrs and zarray are separate in v2, we have to add attributes back prior to `from_dict` + zarray_dict["attributes"] = zattrs_dict + metadata_roundtripped = ArrayV2Metadata.from_dict(zarray_dict) + else: + zarray_dict = json.loads(buffer_dict[ZARR_JSON].to_bytes().decode()) + metadata_roundtripped = ArrayV3Metadata.from_dict(zarray_dict) + + orig = dataclasses.asdict(metadata) + rt = dataclasses.asdict(metadata_roundtripped) + + assert deep_equal(orig, rt), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}" + + # @st.composite # def advanced_indices(draw, *, shape): # basic_idxr = draw( @@ -197,42 +231,66 @@ async def test_roundtrip_array_metadata( # assert_array_equal(nparray, zarray[:]) -@given(array_metadata_v2()) -def test_v2meta_roundtrip(metadata): - buffer_dict = metadata.to_buffer_dict(prototype=default_buffer_prototype()) - zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode()) - zattrs_dict = json.loads(buffer_dict[ZATTRS_JSON].to_bytes().decode()) - - # zattrs and zarray are separate in v2, we have to add attributes back prior to `from_dict` - zarray_dict["attributes"] = zattrs_dict - - metadata_roundtripped = ArrayV2Metadata.from_dict(zarray_dict) - - # Convert both metadata instances to dictionaries. 
- orig = dataclasses.asdict(metadata) - rt = dataclasses.asdict(metadata_roundtripped) - - assert deep_equal(orig, rt), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}" - - -@given(npst.from_dtype(dtype=np.dtype("float64"), allow_nan=True, allow_infinity=True)) -def test_v2meta_nan_and_infinity(fill_value): - metadata = ArrayV2Metadata( - shape=[10], - dtype=np.dtype("float64"), - chunks=[5], - fill_value=fill_value, - order="C", - ) - - buffer_dict = metadata.to_buffer_dict(prototype=default_buffer_prototype()) - zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode()) - - if np.isnan(fill_value): - assert zarray_dict["fill_value"] == "NaN" - elif np.isinf(fill_value) and fill_value > 0: - assert zarray_dict["fill_value"] == "Infinity" - elif np.isinf(fill_value): - assert zarray_dict["fill_value"] == "-Infinity" +def serialized_float_is_valid(serialized): + """ + Validate that the serialized representation of a float conforms to the spec. + + The specification requires that a serialized float must be either: + - A JSON number, or + - One of the strings "NaN", "Infinity", or "-Infinity". + + Args: + serialized: The value produced by JSON serialization for a floating point number. + + Returns: + bool: True if the serialized value is valid according to the spec, False otherwise. + """ + if isinstance(serialized, numbers.Real): + return True + return serialized in ("NaN", "Infinity", "-Infinity") + + +@given(meta=array_metadata()) # type: ignore[misc] +def test_array_metadata_meets_spec(meta: ArrayV2Metadata | ArrayV3Metadata) -> None: + """ + Validate that the array metadata produced by the library conforms to the relevant spec (V2 vs V3). + + For ArrayV2Metadata: + - Ensures that 'zarr_format' is 2. + - Verifies that 'filters' is either None or a tuple (and not an empty tuple). + For ArrayV3Metadata: + - Ensures that 'zarr_format' is 3. 
+ + For both versions: + - If the dtype is a floating point of some kind, verifies of fill values: + * NaN is serialized as the string "NaN" + * Positive Infinity is serialized as the string "Infinity" + * Negative Infinity is serialized as the string "-Infinity" + * Other fill values are preserved as-is. + - If the dtype is a complex number of some kind, verifies that each component of the fill + value (real and imaginary) satisfies the serialization rules for floating point numbers. + - If the dtype is a datetime of some kind, verifies that `NaT` values are serialized as "NaT". + + Note: + This test validates spec-compliance for array metadata serialization. + It is a work-in-progress and should be expanded as further edge cases are identified. + """ + asdict_dict = meta.to_dict() + + # version-specific validations + if isinstance(meta, ArrayV2Metadata): + assert asdict_dict["filters"] != () + assert asdict_dict["filters"] is None or isinstance(asdict_dict["filters"], tuple) + assert asdict_dict["zarr_format"] == 2 else: - assert zarray_dict["fill_value"] == fill_value + assert asdict_dict["zarr_format"] == 3 + + # version-agnostic validations + if meta.dtype.kind == "f": + assert serialized_float_is_valid(asdict_dict["fill_value"]) + elif meta.dtype.kind == "c": + # fill_value should be a two-element array [real, imag]. 
+ assert serialized_float_is_valid(asdict_dict["fill_value"].real) + assert serialized_float_is_valid(asdict_dict["fill_value"].imag) + elif meta.dtype.kind == "M" and np.isnat(meta.fill_value): + assert asdict_dict["fill_value"] == "NaT" From 19d61c909e04b1546576faba8753d15e3547a7b6 Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Mon, 24 Feb 2025 12:38:58 -0600 Subject: [PATCH 3/9] Unify metadata v2 fill value parsing --- src/zarr/core/metadata/v2.py | 122 ++++++++++++++++++++--------------- 1 file changed, 69 insertions(+), 53 deletions(-) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index f5520c9055..a359e116f5 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -2,17 +2,17 @@ import base64 import warnings -from collections.abc import Iterable +from collections.abc import Iterable, Mapping, Sequence from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING, TypedDict, cast +from typing import TYPE_CHECKING, Any, TypedDict, cast import numcodecs.abc from zarr.abc.metadata import Metadata if TYPE_CHECKING: - from typing import Any, Literal, Self + from typing import Literal, Self import numpy.typing as npt @@ -109,6 +109,29 @@ def shards(self) -> ChunkCoords | None: return None def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: + def _serialize_fill_value(fv: Any) -> JSON: + if self.fill_value is None: + pass + elif self.dtype.kind in "SV": + # There's a relationship between self.dtype and self.fill_value + # that mypy isn't aware of. The fact that we have S or V dtype here + # means we should have a bytes-type fill_value. 
+ fv = base64.standard_b64encode(cast(bytes, self.fill_value)).decode("ascii") + elif isinstance(fv, np.datetime64): + if np.isnat(fv): + fv = "NaT" + else: + fv = np.datetime_as_string(fv) + elif isinstance(fv, numbers.Real): + float_fv = float(fv) + if np.isnan(float_fv): + fv = "NaN" + elif np.isinf(float_fv): + fv = "Infinity" if float_fv > 0 else "-Infinity" + elif isinstance(fv, numbers.Complex): + fv = [_serialize_fill_value(fv.real), _serialize_fill_value(fv.imag)] + return cast(JSON, fv) + def _json_convert( o: Any, ) -> Any: @@ -147,6 +170,7 @@ def _json_convert( raise TypeError zarray_dict = self.to_dict() + zarray_dict["fill_value"] = _serialize_fill_value(zarray_dict["fill_value"]) zattrs_dict = zarray_dict.pop("attributes", {}) json_indent = config.get("json_indent") return { @@ -161,38 +185,24 @@ def _json_convert( } @classmethod - def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: + def from_dict(cls, data: dict[str, JSON]) -> ArrayV2Metadata: # Make a copy to protect the original from modification. _data = data.copy() # Check that the zarr_format attribute is correct. 
_ = parse_zarr_format(_data.pop("zarr_format")) - dtype = parse_dtype(_data["dtype"]) - if dtype.kind in "SV": - fill_value_encoded = _data.get("fill_value") - if fill_value_encoded is not None: - fill_value: Any = base64.standard_b64decode(fill_value_encoded) - _data["fill_value"] = fill_value - else: - fill_value = _data.get("fill_value") - if fill_value is not None: - if np.issubdtype(dtype, np.datetime64): - if fill_value == "NaT": - _data["fill_value"] = np.array("NaT", dtype=dtype)[()] - else: - _data["fill_value"] = np.array(fill_value, dtype=dtype)[()] - elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2: - val = complex(float(fill_value[0]), float(fill_value[1])) - _data["fill_value"] = np.array(val, dtype=dtype)[()] - elif dtype.kind in "f" and fill_value in {"NaN", "Infinity", "-Infinity"}: - _data["fill_value"] = np.array(fill_value, dtype=dtype)[()] # zarr v2 allowed arbitrary keys in the metadata. # Filter the keys to only those expected by the constructor. expected = {x.name for x in fields(cls)} expected |= {"dtype", "chunks"} # check if `filters` is an empty sequence; if so use None instead and raise a warning - if _data["filters"] is not None and len(_data["filters"]) == 0: + filters = _data.get("filters") + if ( + isinstance(filters, Sequence) + and not isinstance(filters, (str, bytes)) + and len(filters) == 0 + ): msg = ( "Found an empty list of filters in the array metadata document. " "This is contrary to the Zarr V2 specification, and will cause an error in the future. 
" @@ -203,36 +213,11 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: _data = {k: v for k, v in _data.items() if k in expected} - return cls(**_data) + return cls(**cast(Mapping[str, Any], _data)) def to_dict(self) -> dict[str, JSON]: - def _sanitize_fill_value(fv: Any) -> JSON: - if fv is None: - return fv - elif isinstance(fv, np.datetime64): - if np.isnat(fv): - return "NaT" - return np.datetime_as_string(fv) - elif isinstance(fv, numbers.Real): - float_fv = float(fv) - if np.isnan(float_fv): - fv = "NaN" - elif np.isinf(float_fv): - fv = "Infinity" if float_fv > 0 else "-Infinity" - elif isinstance(fv, numbers.Complex): - fv = [_sanitize_fill_value(fv.real), _sanitize_fill_value(fv.imag)] - return cast(JSON, fv) - zarray_dict = super().to_dict() - if self.dtype.kind in "SV" and self.fill_value is not None: - # There's a relationship between self.dtype and self.fill_value - # that mypy isn't aware of. The fact that we have S or V dtype here - # means we should have a bytes-type fill_value. 
- fill_value = base64.standard_b64encode(cast(bytes, self.fill_value)).decode("ascii") - zarray_dict["fill_value"] = fill_value - - zarray_dict["fill_value"] = _sanitize_fill_value(zarray_dict["fill_value"]) _ = zarray_dict.pop("dtype") dtype_json: JSON # In the case of zarr v2, the simplest i.e., '|VXX' dtype is represented as a string @@ -330,7 +315,25 @@ def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata: return data -def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: +def parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: + """Handle structured dtype/fill value pairs""" + try: + if isinstance(fill_value, (tuple, list)): + fill_value = np.array([fill_value], dtype=dtype)[0] + elif isinstance(fill_value, bytes): + fill_value = np.frombuffer(fill_value, dtype=dtype)[0] + elif isinstance(fill_value, str): + decoded = base64.standard_b64decode(fill_value) + fill_value = np.frombuffer(decoded, dtype=dtype)[0] + else: + fill_value = np.array(fill_value, dtype=dtype)[()] + except Exception as e: + msg = f"Fill_value {fill_value} is not valid for dtype {dtype}." + raise ValueError(msg) from e + return fill_value + + +def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: """ Parse a potential fill value into a value that is compatible with the provided dtype. 
@@ -345,14 +348,15 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: ------- An instance of `dtype`, or `None`, or any python object (in the case of an object dtype) """ + if fill_value is None or dtype.hasobject: - # no fill value pass + elif dtype.fields is not None: + fill_value = parse_structured_fill_value(fill_value, dtype) elif not isinstance(fill_value, np.void) and fill_value == 0: # this should be compatible across numpy versions for any array type, including # structured arrays fill_value = np.zeros((), dtype=dtype)[()] - elif dtype.kind == "U": # special case unicode because of encoding issues on Windows if passed through numpy # https://github.com/alimanfoo/zarr/pull/172#issuecomment-343782713 @@ -361,6 +365,18 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: raise ValueError( f"fill_value {fill_value!r} is not valid for dtype {dtype}; must be a unicode string" ) + elif dtype.kind in "SV" and isinstance(fill_value, str): + fill_value = base64.standard_b64decode(fill_value) + elif np.issubdtype(dtype, np.datetime64): + if fill_value == "NaT": + fill_value = np.array("NaT", dtype=dtype)[()] + else: + fill_value = np.array(fill_value, dtype=dtype)[()] + elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2: + complex_val = complex(float(fill_value[0]), float(fill_value[1])) + fill_value = np.array(complex_val, dtype=dtype)[()] + elif dtype.kind in "f" and fill_value in {"NaN", "Infinity", "-Infinity"}: + fill_value = np.array(fill_value, dtype=dtype)[()] else: try: if isinstance(fill_value, bytes) and dtype.kind == "V": From 29faec3c4edfc30e25a0f35dd47adab1349d0f9d Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Mon, 24 Feb 2025 14:50:43 -0600 Subject: [PATCH 4/9] Test structured fill_value parsing --- src/zarr/core/group.py | 2 +- src/zarr/core/metadata/v2.py | 4 +- src/zarr/testing/strategies.py | 2 +- tests/test_properties.py | 1 + tests/test_v2.py | 84 
++++++++++++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 3 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index a7f8a6c022..b16c130b54 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3458,7 +3458,7 @@ def _build_metadata_v3(zarr_json: dict[str, JSON]) -> ArrayV3Metadata | GroupMet def _build_metadata_v2( - zarr_json: dict[str, object], attrs_json: dict[str, JSON] + zarr_json: dict[str, JSON], attrs_json: dict[str, JSON] ) -> ArrayV2Metadata | GroupMetadata: """ Convert a dict representation of Zarr V2 metadata into the corresponding metadata class. diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index a359e116f5..0e27e20175 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -318,7 +318,9 @@ def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata: def parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: """Handle structured dtype/fill value pairs""" try: - if isinstance(fill_value, (tuple, list)): + if isinstance(fill_value, list): + fill_value = tuple(fill_value) + if isinstance(fill_value, tuple): fill_value = np.array([fill_value], dtype=dtype)[0] elif isinstance(fill_value, bytes): fill_value = np.frombuffer(fill_value, dtype=dtype)[0] diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 96d664f5aa..358a8736f6 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -5,7 +5,7 @@ import hypothesis.extra.numpy as npst import hypothesis.strategies as st import numpy as np -from hypothesis import event, given, settings # noqa: F401 +from hypothesis import event from hypothesis.strategies import SearchStrategy import zarr diff --git a/tests/test_properties.py b/tests/test_properties.py index 7856770dc6..73b9207590 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,6 +1,7 @@ import dataclasses import json import numbers + import numpy as np import 
pytest from numpy.testing import assert_array_equal diff --git a/tests/test_v2.py b/tests/test_v2.py index 0a4487cfcc..3a1acd064f 100644 --- a/tests/test_v2.py +++ b/tests/test_v2.py @@ -15,6 +15,7 @@ from zarr import config from zarr.abc.store import Store from zarr.core.buffer.core import default_buffer_prototype +from zarr.core.metadata.v2 import parse_structured_fill_value from zarr.core.sync import sync from zarr.storage import MemoryStore, StorePath @@ -315,6 +316,89 @@ def test_structured_dtype_roundtrip(fill_value, tmp_path) -> None: assert (a == za[:]).all() +@pytest.mark.parametrize( + ( + "fill_value", + "dtype", + "expected_result", + ), + [ + ( + ("Alice", 30), + np.dtype([("name", "U10"), ("age", "i4")]), + np.array([("Alice", 30)], dtype=[("name", "U10"), ("age", "i4")])[0], + ), + ( + ["Bob", 25], + np.dtype([("name", "U10"), ("age", "i4")]), + np.array([("Bob", 25)], dtype=[("name", "U10"), ("age", "i4")])[0], + ), + ( + b"\x01\x00\x00\x00\x02\x00\x00\x00", + np.dtype([("x", "i4"), ("y", "i4")]), + np.array([(1, 2)], dtype=[("x", "i4"), ("y", "i4")])[0], + ), + ( + "BQAAAA==", + np.dtype([("val", "i4")]), + np.array([(5,)], dtype=[("val", "i4")])[0], + ), + ( + {"x": 1, "y": 2}, + np.dtype([("location", "O")]), + np.array([({"x": 1, "y": 2},)], dtype=[("location", "O")])[0], + ), + ( + {"x": 1, "y": 2, "z": 3}, + np.dtype([("location", "O")]), + np.array([({"x": 1, "y": 2, "z": 3},)], dtype=[("location", "O")])[0], + ), + ], + ids=[ + "tuple_input", + "list_input", + "bytes_input", + "string_input", + "dictionary_input", + "dictionary_input_extra_fields", + ], +) +def test_parse_structured_fill_value_valid( + fill_value: Any, dtype: np.dtype[Any], expected_result: Any +) -> None: + result = parse_structured_fill_value(fill_value, dtype) + assert result.dtype == expected_result.dtype + assert result == expected_result + if isinstance(expected_result, np.void): + for name in expected_result.dtype.names or []: + assert result[name] == 
expected_result[name] + + +@pytest.mark.parametrize( + ( + "fill_value", + "dtype", + ), + [ + (("Alice", 30), np.dtype([("name", "U10"), ("age", "i4"), ("city", "U20")])), + (b"\x01\x00\x00\x00", np.dtype([("x", "i4"), ("y", "i4")])), + ("this_is_not_base64", np.dtype([("val", "i4")])), + ("hello", np.dtype([("age", "i4")])), + ({"x": 1, "y": 2}, np.dtype([("location", "i4")])), + ], + ids=[ + "tuple_list_wrong_length", + "bytes_wrong_length", + "invalid_base64", + "wrong_data_type", + "wrong_dictionary", + ], +) +def test_parse_structured_fill_value_invalid(fill_value: Any, dtype: np.dtype[Any]) -> None: + with pytest.raises(ValueError): + parse_structured_fill_value(fill_value, dtype) + + @pytest.mark.parametrize("fill_value", [None, b"x"], ids=["no_fill", "fill"]) def test_other_dtype_roundtrip(fill_value, tmp_path) -> None: a = np.array([b"a\0\0", b"bb", b"ccc"], dtype="V7") From 042d815c767c25478e960a3905bbe562085b12a6 Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Wed, 5 Mar 2025 13:50:21 -0600 Subject: [PATCH 5/9] Remove redundancies, fix integral handling --- src/zarr/core/metadata/v2.py | 78 +++++++++++++++++++----------------- tests/test_properties.py | 9 +++-- tests/test_v2.py | 6 +-- 3 files changed, 51 insertions(+), 42 deletions(-) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index 0e27e20175..e39ffe9385 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -2,7 +2,7 @@ import base64 import warnings -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Iterable, Sequence from enum import Enum from functools import cached_property from typing import TYPE_CHECKING, Any, TypedDict, cast @@ -109,29 +109,6 @@ def shards(self) -> ChunkCoords | None: return None def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: - def _serialize_fill_value(fv: Any) -> JSON: - if self.fill_value is None: - pass - elif self.dtype.kind in "SV": - # There's a 
relationship between self.dtype and self.fill_value - # that mypy isn't aware of. The fact that we have S or V dtype here - # means we should have a bytes-type fill_value. - fv = base64.standard_b64encode(cast(bytes, self.fill_value)).decode("ascii") - elif isinstance(fv, np.datetime64): - if np.isnat(fv): - fv = "NaT" - else: - fv = np.datetime_as_string(fv) - elif isinstance(fv, numbers.Real): - float_fv = float(fv) - if np.isnan(float_fv): - fv = "NaN" - elif np.isinf(float_fv): - fv = "Infinity" if float_fv > 0 else "-Infinity" - elif isinstance(fv, numbers.Complex): - fv = [_serialize_fill_value(fv.real), _serialize_fill_value(fv.imag)] - return cast(JSON, fv) - def _json_convert( o: Any, ) -> Any: @@ -170,7 +147,7 @@ def _json_convert( raise TypeError zarray_dict = self.to_dict() - zarray_dict["fill_value"] = _serialize_fill_value(zarray_dict["fill_value"]) + zarray_dict["fill_value"] = _serialize_fill_value(self.fill_value, self.dtype) zattrs_dict = zarray_dict.pop("attributes", {}) json_indent = config.get("json_indent") return { @@ -185,7 +162,7 @@ def _json_convert( } @classmethod - def from_dict(cls, data: dict[str, JSON]) -> ArrayV2Metadata: + def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: # Make a copy to protect the original from modification. _data = data.copy() # Check that the zarr_format attribute is correct. 
@@ -213,7 +190,7 @@ def from_dict(cls, data: dict[str, JSON]) -> ArrayV2Metadata: _data = {k: v for k, v in _data.items() if k in expected} - return cls(**cast(Mapping[str, Any], _data)) + return cls(**_data) def to_dict(self) -> dict[str, JSON]: zarray_dict = super().to_dict() @@ -315,7 +292,7 @@ def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata: return data -def parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: +def _parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: """Handle structured dtype/fill value pairs""" try: if isinstance(fill_value, list): @@ -354,7 +331,10 @@ def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: if fill_value is None or dtype.hasobject: pass elif dtype.fields is not None: - fill_value = parse_structured_fill_value(fill_value, dtype) + # the dtype is structured (has multiple fields), so the fill_value might be a + # compound value (e.g., a tuple or dict) that needs field-wise processing. + # We use _parse_structured_fill_value to correctly convert each component. 
+ fill_value = _parse_structured_fill_value(fill_value, dtype) elif not isinstance(fill_value, np.void) and fill_value == 0: # this should be compatible across numpy versions for any array type, including # structured arrays @@ -369,16 +349,9 @@ def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: ) elif dtype.kind in "SV" and isinstance(fill_value, str): fill_value = base64.standard_b64decode(fill_value) - elif np.issubdtype(dtype, np.datetime64): - if fill_value == "NaT": - fill_value = np.array("NaT", dtype=dtype)[()] - else: - fill_value = np.array(fill_value, dtype=dtype)[()] elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2: complex_val = complex(float(fill_value[0]), float(fill_value[1])) fill_value = np.array(complex_val, dtype=dtype)[()] - elif dtype.kind in "f" and fill_value in {"NaN", "Infinity", "-Infinity"}: - fill_value = np.array(fill_value, dtype=dtype)[()] else: try: if isinstance(fill_value, bytes) and dtype.kind == "V": @@ -394,6 +367,39 @@ def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: return fill_value +def _serialize_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> JSON: + serialized: JSON + + if fill_value is None: + serialized = None + elif dtype.kind in "SV": + # There's a relationship between dtype and fill_value + # that mypy isn't aware of. The fact that we have S or V dtype here + # means we should have a bytes-type fill_value. 
+ serialized = base64.standard_b64encode(cast(bytes, fill_value)).decode("ascii") + elif isinstance(fill_value, np.datetime64): + serialized = np.datetime_as_string(fill_value) + elif isinstance(fill_value, numbers.Integral): + serialized = int(fill_value) + elif isinstance(fill_value, numbers.Real): + float_fv = float(fill_value) + if np.isnan(float_fv): + serialized = "NaN" + elif np.isinf(float_fv): + serialized = "Infinity" if float_fv > 0 else "-Infinity" + else: + serialized = float_fv + elif isinstance(fill_value, numbers.Complex): + serialized = [ + _serialize_fill_value(fill_value.real, dtype), + _serialize_fill_value(fill_value.imag, dtype), + ] + else: + serialized = fill_value + + return serialized + + def _default_fill_value(dtype: np.dtype[Any]) -> Any: """ Get the default fill value for a type. diff --git a/tests/test_properties.py b/tests/test_properties.py index 73b9207590..5f7112d85f 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,6 +1,7 @@ import dataclasses import json import numbers +from typing import Any import numpy as np import pytest @@ -12,7 +13,7 @@ import hypothesis.extra.numpy as npst import hypothesis.strategies as st -from hypothesis import assume, given +from hypothesis import assume, given, settings from zarr.abc.store import Store from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON @@ -30,7 +31,7 @@ ) -def deep_equal(a, b): +def deep_equal(a: Any, b: Any) -> bool: """Deep equality check with handling of special cases for array metadata classes""" if isinstance(a, (complex, np.complexfloating)) and isinstance( b, (complex, np.complexfloating) @@ -100,6 +101,8 @@ def test_array_creates_implicit_groups(array): ) +# bump deadline from 200 to 300 to avoid (rare) intermittent timeouts +@settings(deadline=300) @given(data=st.data()) def test_basic_indexing(data: st.DataObject) -> None: zarray = data.draw(simple_arrays()) @@ -236,7 +239,7 @@ def test_roundtrip_array_metadata_from_json(data: 
st.DataObject, zarr_format: in # assert_array_equal(nparray, zarray[:]) -def serialized_float_is_valid(serialized): +def serialized_float_is_valid(serialized: numbers.Real | str) -> bool: """ Validate that the serialized representation of a float conforms to the spec. diff --git a/tests/test_v2.py b/tests/test_v2.py index 3a1acd064f..3a36bc01fd 100644 --- a/tests/test_v2.py +++ b/tests/test_v2.py @@ -15,7 +15,7 @@ from zarr import config from zarr.abc.store import Store from zarr.core.buffer.core import default_buffer_prototype -from zarr.core.metadata.v2 import parse_structured_fill_value +from zarr.core.metadata.v2 import _parse_structured_fill_value from zarr.core.sync import sync from zarr.storage import MemoryStore, StorePath @@ -366,7 +366,7 @@ def test_structured_dtype_roundtrip(fill_value, tmp_path) -> None: def test_parse_structured_fill_value_valid( fill_value: Any, dtype: np.dtype[Any], expected_result: Any ) -> None: - result = parse_structured_fill_value(fill_value, dtype) + result = _parse_structured_fill_value(fill_value, dtype) assert result.dtype == expected_result.dtype assert result == expected_result if isinstance(expected_result, np.void): @@ -396,7 +396,7 @@ def test_parse_structured_fill_value_valid( ) def test_parse_structured_fill_value_invalid(fill_value: Any, dtype: np.dtype[Any]) -> None: with pytest.raises(ValueError): - parse_structured_fill_value(fill_value, dtype) + _parse_structured_fill_value(fill_value, dtype) @pytest.mark.parametrize("fill_value", [None, b"x"], ids=["no_fill", "fill"]) From 1388a3b4f734d37b97deee7eb08a5d7b663ec148 Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Wed, 5 Mar 2025 13:54:25 -0600 Subject: [PATCH 6/9] Reorganize structured fill parsing --- src/zarr/core/metadata/v2.py | 17 ++++++++--------- tests/test_v2.py | 1 + 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index e39ffe9385..11f14b37aa 100644 --- 
a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -294,22 +294,21 @@ def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata: def _parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: """Handle structured dtype/fill value pairs""" + print("FILL VALUE", fill_value, "DT", dtype) try: if isinstance(fill_value, list): - fill_value = tuple(fill_value) - if isinstance(fill_value, tuple): - fill_value = np.array([fill_value], dtype=dtype)[0] + return np.array([tuple(fill_value)], dtype=dtype)[0] + elif isinstance(fill_value, tuple): + return np.array([fill_value], dtype=dtype)[0] elif isinstance(fill_value, bytes): - fill_value = np.frombuffer(fill_value, dtype=dtype)[0] + return np.frombuffer(fill_value, dtype=dtype)[0] elif isinstance(fill_value, str): decoded = base64.standard_b64decode(fill_value) - fill_value = np.frombuffer(decoded, dtype=dtype)[0] + return np.frombuffer(decoded, dtype=dtype)[0] else: - fill_value = np.array(fill_value, dtype=dtype)[()] + return np.array(fill_value, dtype=dtype)[()] except Exception as e: - msg = f"Fill_value {fill_value} is not valid for dtype {dtype}." 
- raise ValueError(msg) from e - return fill_value + raise ValueError(f"Fill_value {fill_value} is not valid for dtype {dtype}.") from e def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: diff --git a/tests/test_v2.py b/tests/test_v2.py index 3a36bc01fd..de4ba66a04 100644 --- a/tests/test_v2.py +++ b/tests/test_v2.py @@ -367,6 +367,7 @@ def test_parse_structured_fill_value_valid( fill_value: Any, dtype: np.dtype[Any], expected_result: Any ) -> None: result = _parse_structured_fill_value(fill_value, dtype) + print(result) assert result.dtype == expected_result.dtype assert result == expected_result if isinstance(expected_result, np.void): From c46b27c5c86818e0a5a00da7216b0c14333abd1e Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Wed, 5 Mar 2025 14:48:27 -0600 Subject: [PATCH 7/9] Bump up hypothesis deadline --- tests/test_properties.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_properties.py b/tests/test_properties.py index 5f7112d85f..cd2201c393 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -101,8 +101,8 @@ def test_array_creates_implicit_groups(array): ) -# bump deadline from 200 to 300 to avoid (rare) intermittent timeouts -@settings(deadline=300) +# bump deadline from 200 to 500 to avoid (rare) intermittent timeouts +@settings(deadline=500) @given(data=st.data()) def test_basic_indexing(data: st.DataObject) -> None: zarray = data.draw(simple_arrays()) From 11e2520802a5c502a08bbf13756751e7af4bafdc Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Wed, 5 Mar 2025 15:00:33 -0600 Subject: [PATCH 8/9] Remove hypothesis deadline --- tests/test_properties.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_properties.py b/tests/test_properties.py index cd2201c393..d48dfe2fef 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -101,8 +101,8 @@ def test_array_creates_implicit_groups(array): ) -# bump deadline from 200 to 500 to 
avoid (rare) intermittent timeouts -@settings(deadline=500) +# this decorator removes timeout; not ideal but it should avoid intermittent CI failures +@settings(deadline=None) @given(data=st.data()) def test_basic_indexing(data: st.DataObject) -> None: zarray = data.draw(simple_arrays()) From 5c6166fd6263f1b3206e90d6431a374a8db38987 Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Wed, 2 Apr 2025 13:53:32 -0500 Subject: [PATCH 9/9] Update tests/test_v2.py Co-authored-by: David Stansby --- tests/test_v2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_v2.py b/tests/test_v2.py index de4ba66a04..3a36bc01fd 100644 --- a/tests/test_v2.py +++ b/tests/test_v2.py @@ -367,7 +367,6 @@ def test_parse_structured_fill_value_valid( fill_value: Any, dtype: np.dtype[Any], expected_result: Any ) -> None: result = _parse_structured_fill_value(fill_value, dtype) - print(result) assert result.dtype == expected_result.dtype assert result == expected_result if isinstance(expected_result, np.void):