Skip to content

Commit 6301b15

Browse files
committed
Round-trip serialization for array metadata v2/v3
1 parent fd43cbf commit 6301b15

File tree

5 files changed

+124
-200
lines changed

5 files changed

+124
-200
lines changed

changes/2802.fix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix `fill_value` serialization for `NaN` in `ArrayV2Metadata` and add property-based testing of round-trip serialization.

src/zarr/core/metadata/v2.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
170170
if dtype.kind in "SV":
171171
fill_value_encoded = _data.get("fill_value")
172172
if fill_value_encoded is not None:
173-
fill_value = base64.standard_b64decode(fill_value_encoded)
173+
fill_value: Any = base64.standard_b64decode(fill_value_encoded)
174174
_data["fill_value"] = fill_value
175175
else:
176176
fill_value = _data.get("fill_value")
@@ -180,13 +180,11 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
180180
_data["fill_value"] = np.array("NaT", dtype=dtype)[()]
181181
else:
182182
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
183-
elif dtype.kind == "c" and isinstance(fill_value, list):
184-
if len(fill_value) == 2:
185-
val = complex(float(fill_value[0]), float(fill_value[1]))
186-
_data["fill_value"] = np.array(val, dtype=dtype)[()]
187-
elif dtype.kind in "f" and isinstance(fill_value, str):
188-
if fill_value in {"NaN", "Infinity", "-Infinity"}:
189-
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
183+
elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2:
184+
val = complex(float(fill_value[0]), float(fill_value[1]))
185+
_data["fill_value"] = np.array(val, dtype=dtype)[()]
186+
elif dtype.kind in "f" and fill_value in {"NaN", "Infinity", "-Infinity"}:
187+
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
190188
# zarr v2 allowed arbitrary keys in the metadata.
191189
# Filter the keys to only those expected by the constructor.
192190
expected = {x.name for x in fields(cls)}
@@ -196,21 +194,22 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
196194
return cls(**_data)
197195

198196
def to_dict(self) -> dict[str, JSON]:
199-
def _sanitize_fill_value(fv: Any):
197+
def _sanitize_fill_value(fv: Any) -> JSON:
200198
if fv is None:
201199
return fv
202200
elif isinstance(fv, np.datetime64):
203201
if np.isnat(fv):
204202
return "NaT"
205203
return np.datetime_as_string(fv)
206204
elif isinstance(fv, numbers.Real):
207-
if np.isnan(fv):
205+
float_fv = float(fv)
206+
if np.isnan(float_fv):
208207
fv = "NaN"
209-
elif np.isinf(fv):
210-
fv = "Infinity" if fv > 0 else "-Infinity"
208+
elif np.isinf(float_fv):
209+
fv = "Infinity" if float_fv > 0 else "-Infinity"
211210
elif isinstance(fv, numbers.Complex):
212211
fv = [_sanitize_fill_value(fv.real), _sanitize_fill_value(fv.imag)]
213-
return fv
212+
return cast(JSON, fv)
214213

215214
zarray_dict = super().to_dict()
216215

src/zarr/testing/stateful.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def add_group(self, name: str, data: DataObject) -> None:
8585
@rule(
8686
data=st.data(),
8787
name=node_names,
88-
array_and_chunks=np_array_and_chunks(nparrays=numpy_arrays(zarr_formats=st.just(3))),
88+
array_and_chunks=np_array_and_chunks(arrays=numpy_arrays(zarr_formats=st.just(3))),
8989
)
9090
def add_array(
9191
self,

src/zarr/testing/strategies.py

Lines changed: 0 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import hypothesis.extra.numpy as npst
55
import hypothesis.strategies as st
6-
import numcodecs
76
import numpy as np
87
from hypothesis import assume, given, settings # noqa: F401
98
from hypothesis.strategies import SearchStrategy
@@ -345,136 +344,3 @@ def make_request(start: int, length: int) -> RangeByteRequest:
345344
)
346345
key_tuple = st.tuples(keys, byte_ranges)
347346
return st.lists(key_tuple, min_size=1, max_size=10)
348-
349-
350-
def simple_text():
351-
"""A strategy for generating simple text strings."""
352-
return st.text(st.characters(min_codepoint=32, max_codepoint=126), min_size=1, max_size=10)
353-
354-
355-
def simple_attrs():
356-
"""A strategy for generating simple attribute dictionaries."""
357-
return st.dictionaries(
358-
simple_text(),
359-
st.one_of(
360-
st.integers(),
361-
st.floats(allow_nan=False, allow_infinity=False),
362-
st.booleans(),
363-
simple_text(),
364-
),
365-
)
366-
367-
368-
def array_shapes(min_dims=1, max_dims=3, max_len=100):
369-
"""A strategy for generating array shapes."""
370-
return st.lists(
371-
st.integers(min_value=1, max_value=max_len), min_size=min_dims, max_size=max_dims
372-
)
373-
374-
375-
# def zarr_compressors():
376-
# """A strategy for generating Zarr compressors."""
377-
# return st.sampled_from([None, Blosc(), GZip(), Zstd(), LZ4()])
378-
379-
380-
# def zarr_codecs():
381-
# """A strategy for generating Zarr codecs."""
382-
# return st.sampled_from([BytesCodec(), Blosc(), GZip(), Zstd(), LZ4()])
383-
384-
385-
def zarr_filters():
386-
"""A strategy for generating Zarr filters."""
387-
return st.lists(
388-
st.just(numcodecs.Delta(dtype="i4")), min_size=0, max_size=2
389-
) # Example filter, expand as needed
390-
391-
392-
def zarr_storage_transformers():
393-
"""A strategy for generating Zarr storage transformers."""
394-
return st.lists(
395-
st.dictionaries(
396-
simple_text(), st.one_of(st.integers(), st.floats(), st.booleans(), simple_text())
397-
),
398-
min_size=0,
399-
max_size=2,
400-
)
401-
402-
403-
@st.composite
404-
def array_metadata_v2(draw: st.DrawFn) -> ArrayV2Metadata:
405-
"""Generates valid ArrayV2Metadata objects for property-based testing."""
406-
dims = draw(st.integers(min_value=1, max_value=3)) # Limit dimensions for complexity
407-
shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100)))
408-
max_chunk_len = max(shape) if shape else 100
409-
chunks = tuple(
410-
draw(
411-
st.lists(
412-
st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims
413-
)
414-
)
415-
)
416-
417-
# Validate shape and chunks relationship
418-
assume(all(c <= s for s, c in zip(shape, chunks, strict=False))) # Chunk size must be <= shape
419-
420-
dtype = draw(v2_dtypes())
421-
fill_value = draw(st.one_of([st.none(), npst.from_dtype(dtype)]))
422-
order = draw(st.sampled_from(["C", "F"]))
423-
dimension_separator = draw(st.sampled_from([".", "/"]))
424-
# compressor = draw(zarr_compressors())
425-
filters = tuple(draw(zarr_filters())) if draw(st.booleans()) else None
426-
attributes = draw(simple_attrs())
427-
428-
# Construct the metadata object. Type hints are crucial here for correctness.
429-
return ArrayV2Metadata(
430-
shape=shape,
431-
dtype=dtype,
432-
chunks=chunks,
433-
fill_value=fill_value,
434-
order=order,
435-
dimension_separator=dimension_separator,
436-
# compressor=compressor,
437-
filters=filters,
438-
attributes=attributes,
439-
)
440-
441-
442-
@st.composite
443-
def array_metadata_v3(draw: st.DrawFn) -> ArrayV3Metadata:
444-
"""Generates valid ArrayV3Metadata objects for property-based testing."""
445-
dims = draw(st.integers(min_value=1, max_value=3))
446-
shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100)))
447-
max_chunk_len = max(shape) if shape else 100
448-
chunks = tuple(
449-
draw(
450-
st.lists(
451-
st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims
452-
)
453-
)
454-
)
455-
assume(all(c <= s for s, c in zip(shape, chunks, strict=False)))
456-
457-
dtype = draw(v3_dtypes())
458-
fill_value = draw(npst.from_dtype(dtype))
459-
chunk_grid = RegularChunkGrid(chunks) # Ensure chunks is passed as tuple.
460-
chunk_key_encoding = DefaultChunkKeyEncoding(separator="/") # Or st.sampled_from(["/", "."])
461-
# codecs = tuple(draw(st.lists(zarr_codecs(), min_size=0, max_size=3)))
462-
attributes = draw(simple_attrs())
463-
dimension_names = (
464-
tuple(draw(st.lists(st.one_of(st.none(), simple_text()), min_size=dims, max_size=dims)))
465-
if draw(st.booleans())
466-
else None
467-
)
468-
storage_transformers = tuple(draw(zarr_storage_transformers()))
469-
470-
return ArrayV3Metadata(
471-
shape=shape,
472-
data_type=dtype,
473-
chunk_grid=chunk_grid,
474-
chunk_key_encoding=chunk_key_encoding,
475-
fill_value=fill_value,
476-
# codecs=codecs,
477-
attributes=attributes,
478-
dimension_names=dimension_names,
479-
storage_transformers=storage_transformers,
480-
)

0 commit comments

Comments
 (0)