diff --git a/changes/2802.fix.rst b/changes/2802.fix.rst new file mode 100644 index 0000000000..471ddf66f4 --- /dev/null +++ b/changes/2802.fix.rst @@ -0,0 +1 @@ +Fix `fill_value` serialization for `NaN` in `ArrayV2Metadata` and add property-based testing of round-trip serialization \ No newline at end of file diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 8e7f7f3474..5500bdd4a5 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3456,7 +3456,7 @@ def _build_metadata_v3(zarr_json: dict[str, JSON]) -> ArrayV3Metadata | GroupMet def _build_metadata_v2( - zarr_json: dict[str, object], attrs_json: dict[str, JSON] + zarr_json: dict[str, JSON], attrs_json: dict[str, JSON] ) -> ArrayV2Metadata | GroupMetadata: """ Convert a dict representation of Zarr V2 metadata into the corresponding metadata class. diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index 823944e067..11f14b37aa 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -2,17 +2,17 @@ import base64 import warnings -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING, TypedDict, cast +from typing import TYPE_CHECKING, Any, TypedDict, cast import numcodecs.abc from zarr.abc.metadata import Metadata if TYPE_CHECKING: - from typing import Any, Literal, Self + from typing import Literal, Self import numpy.typing as npt @@ -20,6 +20,7 @@ from zarr.core.common import ChunkCoords import json +import numbers from dataclasses import dataclass, field, fields, replace import numcodecs @@ -146,41 +147,39 @@ def _json_convert( raise TypeError zarray_dict = self.to_dict() + zarray_dict["fill_value"] = _serialize_fill_value(self.fill_value, self.dtype) zattrs_dict = zarray_dict.pop("attributes", {}) json_indent = config.get("json_indent") return { ZARRAY_JSON: prototype.buffer.from_bytes( - json.dumps(zarray_dict, default=_json_convert, indent=json_indent).encode() + json.dumps( + zarray_dict, default=_json_convert, indent=json_indent, allow_nan=False + ).encode() ), ZATTRS_JSON: prototype.buffer.from_bytes( - json.dumps(zattrs_dict, indent=json_indent).encode() + json.dumps(zattrs_dict, indent=json_indent, allow_nan=False).encode() ), } @classmethod def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: - # make a copy to protect the original from modification + # Make a copy to protect the original from modification. _data = data.copy() - # check that the zarr_format attribute is correct + # Check that the zarr_format attribute is correct. _ = parse_zarr_format(_data.pop("zarr_format")) - dtype = parse_dtype(_data["dtype"]) - if dtype.kind in "SV": - fill_value_encoded = _data.get("fill_value") - if fill_value_encoded is not None: - fill_value = base64.standard_b64decode(fill_value_encoded) - _data["fill_value"] = fill_value - - # zarr v2 allowed arbitrary keys here. - # We don't want the ArrayV2Metadata constructor to fail just because someone put an - # extra key in the metadata. + # zarr v2 allowed arbitrary keys in the metadata. + # Filter the keys to only those expected by the constructor. expected = {x.name for x in fields(cls)} - # https://github.com/zarr-developers/zarr-python/issues/2269 - # handle the renames expected |= {"dtype", "chunks"} # check if `filters` is an empty sequence; if so use None instead and raise a warning - if _data["filters"] is not None and len(_data["filters"]) == 0: + filters = _data.get("filters") + if ( + isinstance(filters, Sequence) + and not isinstance(filters, (str, bytes)) + and len(filters) == 0 + ): msg = ( "Found an empty list of filters in the array metadata document. " "This is contrary to the Zarr V2 specification, and will cause an error in the future. " @@ -196,13 +195,6 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: def to_dict(self) -> dict[str, JSON]: zarray_dict = super().to_dict() - if self.dtype.kind in "SV" and self.fill_value is not None: - # There's a relationship between self.dtype and self.fill_value - # that mypy isn't aware of. The fact that we have S or V dtype here - # means we should have a bytes-type fill_value. - fill_value = base64.standard_b64encode(cast(bytes, self.fill_value)).decode("ascii") - zarray_dict["fill_value"] = fill_value - _ = zarray_dict.pop("dtype") dtype_json: JSON # In the case of zarr v2, the simplest i.e., '|VXX' dtype is represented as a string @@ -300,7 +292,26 @@ def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata: return data -def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: +def _parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: + """Handle structured dtype/fill value pairs""" + print("FILL VALUE", fill_value, "DT", dtype) + try: + if isinstance(fill_value, list): + return np.array([tuple(fill_value)], dtype=dtype)[0] + elif isinstance(fill_value, tuple): + return np.array([fill_value], dtype=dtype)[0] + elif isinstance(fill_value, bytes): + return np.frombuffer(fill_value, dtype=dtype)[0] + elif isinstance(fill_value, str): + decoded = base64.standard_b64decode(fill_value) + return np.frombuffer(decoded, dtype=dtype)[0] + else: + return np.array(fill_value, dtype=dtype)[()] + except Exception as e: + raise ValueError(f"Fill_value {fill_value} is not valid for dtype {dtype}.") from e + + +def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: """ Parse a potential fill value into a value that is compatible with the provided dtype. @@ -317,13 +328,16 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: """ if fill_value is None or dtype.hasobject: - # no fill value pass + elif dtype.fields is not None: + # the dtype is structured (has multiple fields), so the fill_value might be a + # compound value (e.g., a tuple or dict) that needs field-wise processing. + # We use parse_structured_fill_value to correctly convert each component. + fill_value = _parse_structured_fill_value(fill_value, dtype) elif not isinstance(fill_value, np.void) and fill_value == 0: # this should be compatible across numpy versions for any array type, including # structured arrays fill_value = np.zeros((), dtype=dtype)[()] - elif dtype.kind == "U": # special case unicode because of encoding issues on Windows if passed through numpy # https://github.com/alimanfoo/zarr/pull/172#issuecomment-343782713 @@ -332,6 +346,11 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: raise ValueError( f"fill_value {fill_value!r} is not valid for dtype {dtype}; must be a unicode string" ) + elif dtype.kind in "SV" and isinstance(fill_value, str): + fill_value = base64.standard_b64decode(fill_value) + elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2: + complex_val = complex(float(fill_value[0]), float(fill_value[1])) + fill_value = np.array(complex_val, dtype=dtype)[()] else: try: if isinstance(fill_value, bytes) and dtype.kind == "V": @@ -347,6 +366,39 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: return fill_value +def _serialize_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> JSON: + serialized: JSON + + if fill_value is None: + serialized = None + elif dtype.kind in "SV": + # There's a relationship between dtype and fill_value + # that mypy isn't aware of. The fact that we have S or V dtype here + # means we should have a bytes-type fill_value. + serialized = base64.standard_b64encode(cast(bytes, fill_value)).decode("ascii") + elif isinstance(fill_value, np.datetime64): + serialized = np.datetime_as_string(fill_value) + elif isinstance(fill_value, numbers.Integral): + serialized = int(fill_value) + elif isinstance(fill_value, numbers.Real): + float_fv = float(fill_value) + if np.isnan(float_fv): + serialized = "NaN" + elif np.isinf(float_fv): + serialized = "Infinity" if float_fv > 0 else "-Infinity" + else: + serialized = float_fv + elif isinstance(fill_value, numbers.Complex): + serialized = [ + _serialize_fill_value(fill_value.real, dtype), + _serialize_fill_value(fill_value.imag, dtype), + ] + else: + serialized = fill_value + + return serialized + + def _default_fill_value(dtype: np.dtype[Any]) -> Any: """ Get the default fill value for a type. diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index aa42329be7..f2dc38483a 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -5,7 +5,7 @@ import hypothesis.extra.numpy as npst import hypothesis.strategies as st import numpy as np -from hypothesis import event, given, settings # noqa: F401 +from hypothesis import event from hypothesis.strategies import SearchStrategy import zarr diff --git a/tests/test_properties.py b/tests/test_properties.py index 68d8bb0a0e..d48dfe2fef 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,3 +1,8 @@ +import dataclasses +import json +import numbers +from typing import Any + import numpy as np import pytest from numpy.testing import assert_array_equal @@ -8,9 +13,10 @@ import hypothesis.extra.numpy as npst import hypothesis.strategies as st -from hypothesis import assume, given +from hypothesis import assume, given, settings from zarr.abc.store import Store +from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.sync import sync from zarr.testing.strategies import ( @@ -25,8 +31,53 @@ ) +def deep_equal(a: Any, b: Any) -> bool: + """Deep equality check with handling of special cases for array metadata classes""" + if isinstance(a, (complex, np.complexfloating)) and isinstance( + b, (complex, np.complexfloating) + ): + a_real, a_imag = float(a.real), float(a.imag) + b_real, b_imag = float(b.real), float(b.imag) + if np.isnan(a_real) and np.isnan(b_real): + real_eq = True + else: + real_eq = a_real == b_real + if np.isnan(a_imag) and np.isnan(b_imag): + imag_eq = True + else: + imag_eq = a_imag == b_imag + return real_eq and imag_eq + + if isinstance(a, (float, np.floating)) and isinstance(b, (float, np.floating)): + if np.isnan(a) and np.isnan(b): + return True + return a == b + + if isinstance(a, np.datetime64) and isinstance(b, np.datetime64): + if np.isnat(a) and np.isnat(b): + return True + return a == b + + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + if a.shape != b.shape: + return False + return all(deep_equal(x, y) for x, y in zip(a.flat, b.flat, strict=False)) + + if isinstance(a, dict) and isinstance(b, dict): + if set(a.keys()) != set(b.keys()): + return False + return all(deep_equal(a[k], b[k]) for k in a) + + if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)): + if len(a) != len(b): + return False + return all(deep_equal(x, y) for x, y in zip(a, b, strict=False)) + + return a == b + + @given(data=st.data(), zarr_format=zarr_formats) -def test_roundtrip(data: st.DataObject, zarr_format: int) -> None: +def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None: nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format))) zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format))) assert_array_equal(nparray, zarray[:]) @@ -50,6 +101,8 @@ def test_array_creates_implicit_groups(array): ) +# this decorator removes timeout; not ideal but it should avoid intermittent CI failures +@settings(deadline=None) @given(data=st.data()) def test_basic_indexing(data: st.DataObject) -> None: zarray = data.draw(simple_arrays()) @@ -109,9 +162,17 @@ def test_vindex(data: st.DataObject) -> None: @given(store=stores, meta=array_metadata()) # type: ignore[misc] -async def test_roundtrip_array_metadata( +async def test_roundtrip_array_metadata_from_store( store: Store, meta: ArrayV2Metadata | ArrayV3Metadata ) -> None: + """ + Verify that the I/O for metadata in a store are lossless. + + This test serializes an ArrayV2Metadata or ArrayV3Metadata object to a dict + of buffers via `to_buffer_dict`, writes each buffer to a store under keys + prefixed with "0/", and then reads them back. The test asserts that each + retrieved buffer exactly matches the original buffer. + """ asdict = meta.to_buffer_dict(prototype=default_buffer_prototype()) for key, expected in asdict.items(): await store.set(f"0/{key}", expected) @@ -119,18 +180,39 @@ async def test_roundtrip_array_metadata( assert actual == expected -@given(store=stores, meta=array_metadata()) # type: ignore[misc] -def test_array_metadata_meets_spec(store: Store, meta: ArrayV2Metadata | ArrayV3Metadata) -> None: - # TODO: fill this out - asdict = meta.to_dict() - if isinstance(meta, ArrayV2Metadata): - assert asdict["filters"] != () - assert asdict["filters"] is None or isinstance(asdict["filters"], tuple) - assert asdict["zarr_format"] == 2 - elif isinstance(meta, ArrayV3Metadata): - assert asdict["zarr_format"] == 3 +@given(data=st.data(), zarr_format=zarr_formats) +def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: int) -> None: + """ + Verify that JSON serialization and deserialization of metadata is lossless. + + For Zarr v2: + - The metadata is split into two JSON documents (one for array data and one + for attributes). The test merges the attributes back before deserialization. + For Zarr v3: + - All metadata is stored in a single JSON document. No manual merger is necessary. + + The test then converts both the original and round-tripped metadata objects + into dictionaries using `dataclasses.asdict` and uses a deep equality check + to verify that the roundtrip has preserved all fields (including special + cases like NaN, Infinity, complex numbers, and datetime values). + """ + metadata = data.draw(array_metadata(zarr_formats=st.just(zarr_format))) + buffer_dict = metadata.to_buffer_dict(prototype=default_buffer_prototype()) + + if zarr_format == 2: + zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode()) + zattrs_dict = json.loads(buffer_dict[ZATTRS_JSON].to_bytes().decode()) + # zattrs and zarray are separate in v2, we have to add attributes back prior to `from_dict` + zarray_dict["attributes"] = zattrs_dict + metadata_roundtripped = ArrayV2Metadata.from_dict(zarray_dict) else: - raise NotImplementedError + zarray_dict = json.loads(buffer_dict[ZARR_JSON].to_bytes().decode()) + metadata_roundtripped = ArrayV3Metadata.from_dict(zarray_dict) + + orig = dataclasses.asdict(metadata) + rt = dataclasses.asdict(metadata_roundtripped) + + assert deep_equal(orig, rt), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}" # @st.composite @@ -155,3 +237,68 @@ def test_array_metadata_meets_spec(store: Store, meta: ArrayV2Metadata | ArrayV3 # nparray = data.draw(np_arrays) # zarray = data.draw(arrays(arrays=st.just(nparray))) # assert_array_equal(nparray, zarray[:]) + + +def serialized_float_is_valid(serialized: numbers.Real | str) -> bool: + """ + Validate that the serialized representation of a float conforms to the spec. + + The specification requires that a serialized float must be either: + - A JSON number, or + - One of the strings "NaN", "Infinity", or "-Infinity". + + Args: + serialized: The value produced by JSON serialization for a floating point number. + + Returns: + bool: True if the serialized value is valid according to the spec, False otherwise. + """ + if isinstance(serialized, numbers.Real): + return True + return serialized in ("NaN", "Infinity", "-Infinity") + + +@given(meta=array_metadata()) # type: ignore[misc] +def test_array_metadata_meets_spec(meta: ArrayV2Metadata | ArrayV3Metadata) -> None: + """ + Validate that the array metadata produced by the library conforms to the relevant spec (V2 vs V3). + + For ArrayV2Metadata: + - Ensures that 'zarr_format' is 2. + - Verifies that 'filters' is either None or a tuple (and not an empty tuple). + For ArrayV3Metadata: + - Ensures that 'zarr_format' is 3. + + For both versions: + - If the dtype is a floating point of some kind, verifies of fill values: + * NaN is serialized as the string "NaN" + * Positive Infinity is serialized as the string "Infinity" + * Negative Infinity is serialized as the string "-Infinity" + * Other fill values are preserved as-is. + - If the dtype is a complex number of some kind, verifies that each component of the fill + value (real and imaginary) satisfies the serialization rules for floating point numbers. + - If the dtype is a datetime of some kind, verifies that `NaT` values are serialized as "NaT". + + Note: + This test validates spec-compliance for array metadata serialization. + It is a work-in-progress and should be expanded as further edge cases are identified. + """ + asdict_dict = meta.to_dict() + + # version-specific validations + if isinstance(meta, ArrayV2Metadata): + assert asdict_dict["filters"] != () + assert asdict_dict["filters"] is None or isinstance(asdict_dict["filters"], tuple) + assert asdict_dict["zarr_format"] == 2 + else: + assert asdict_dict["zarr_format"] == 3 + + # version-agnostic validations + if meta.dtype.kind == "f": + assert serialized_float_is_valid(asdict_dict["fill_value"]) + elif meta.dtype.kind == "c": + # fill_value should be a two-element array [real, imag]. + assert serialized_float_is_valid(asdict_dict["fill_value"].real) + assert serialized_float_is_valid(asdict_dict["fill_value"].imag) + elif meta.dtype.kind == "M" and np.isnat(meta.fill_value): + assert asdict_dict["fill_value"] == "NaT" diff --git a/tests/test_v2.py b/tests/test_v2.py index 0a4487cfcc..3a36bc01fd 100644 --- a/tests/test_v2.py +++ b/tests/test_v2.py @@ -15,6 +15,7 @@ from zarr import config from zarr.abc.store import Store from zarr.core.buffer.core import default_buffer_prototype +from zarr.core.metadata.v2 import _parse_structured_fill_value from zarr.core.sync import sync from zarr.storage import MemoryStore, StorePath @@ -315,6 +316,89 @@ def test_structured_dtype_roundtrip(fill_value, tmp_path) -> None: assert (a == za[:]).all() +@pytest.mark.parametrize( + ( + "fill_value", + "dtype", + "expected_result", + ), + [ + ( + ("Alice", 30), + np.dtype([("name", "U10"), ("age", "i4")]), + np.array([("Alice", 30)], dtype=[("name", "U10"), ("age", "i4")])[0], + ), + ( + ["Bob", 25], + np.dtype([("name", "U10"), ("age", "i4")]), + np.array([("Bob", 25)], dtype=[("name", "U10"), ("age", "i4")])[0], + ), + ( + b"\x01\x00\x00\x00\x02\x00\x00\x00", + np.dtype([("x", "i4"), ("y", "i4")]), + np.array([(1, 2)], dtype=[("x", "i4"), ("y", "i4")])[0], + ), + ( + "BQAAAA==", + np.dtype([("val", "i4")]), + np.array([(5,)], dtype=[("val", "i4")])[0], + ), + ( + {"x": 1, "y": 2}, + np.dtype([("location", "O")]), + np.array([({"x": 1, "y": 2},)], dtype=[("location", "O")])[0], + ), + ( + {"x": 1, "y": 2, "z": 3}, + np.dtype([("location", "O")]), + np.array([({"x": 1, "y": 2, "z": 3},)], dtype=[("location", "O")])[0], + ), + ], + ids=[ + "tuple_input", + "list_input", + "bytes_input", + "string_input", + "dictionary_input", + "dictionary_input_extra_fields", + ], +) +def test_parse_structured_fill_value_valid( + fill_value: Any, dtype: np.dtype[Any], expected_result: Any +) -> None: + result = _parse_structured_fill_value(fill_value, dtype) + assert result.dtype == expected_result.dtype + assert result == expected_result + if isinstance(expected_result, np.void): + for name in expected_result.dtype.names or []: + assert result[name] == expected_result[name] + + +@pytest.mark.parametrize( + ( + "fill_value", + "dtype", + ), + [ + (("Alice", 30), np.dtype([("name", "U10"), ("age", "i4"), ("city", "U20")])), + (b"\x01\x00\x00\x00", np.dtype([("x", "i4"), ("y", "i4")])), + ("this_is_not_base64", np.dtype([("val", "i4")])), + ("hello", np.dtype([("age", "i4")])), + ({"x": 1, "y": 2}, np.dtype([("location", "i4")])), + ], + ids=[ + "tuple_list_wrong_length", + "bytes_wrong_length", + "invalid_base64", + "wrong_data_type", + "wrong_dictionary", + ], +) +def test_parse_structured_fill_value_invalid(fill_value: Any, dtype: np.dtype[Any]) -> None: + with pytest.raises(ValueError): + _parse_structured_fill_value(fill_value, dtype) + + @pytest.mark.parametrize("fill_value", [None, b"x"], ids=["no_fill", "fill"]) def test_other_dtype_roundtrip(fill_value, tmp_path) -> None: a = np.array([b"a\0\0", b"bb", b"ccc"], dtype="V7")