Skip to content

Commit 042d815

Browse files
committed
Remove redundancies, fix integral handling
1 parent a59f9ac commit 042d815

File tree

3 files changed

+51
-42
lines changed

3 files changed

+51
-42
lines changed

src/zarr/core/metadata/v2.py

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import base64
44
import warnings
5-
from collections.abc import Iterable, Mapping, Sequence
5+
from collections.abc import Iterable, Sequence
66
from enum import Enum
77
from functools import cached_property
88
from typing import TYPE_CHECKING, Any, TypedDict, cast
@@ -109,29 +109,6 @@ def shards(self) -> ChunkCoords | None:
109109
return None
110110

111111
def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
112-
def _serialize_fill_value(fv: Any) -> JSON:
113-
if self.fill_value is None:
114-
pass
115-
elif self.dtype.kind in "SV":
116-
# There's a relationship between self.dtype and self.fill_value
117-
# that mypy isn't aware of. The fact that we have S or V dtype here
118-
# means we should have a bytes-type fill_value.
119-
fv = base64.standard_b64encode(cast(bytes, self.fill_value)).decode("ascii")
120-
elif isinstance(fv, np.datetime64):
121-
if np.isnat(fv):
122-
fv = "NaT"
123-
else:
124-
fv = np.datetime_as_string(fv)
125-
elif isinstance(fv, numbers.Real):
126-
float_fv = float(fv)
127-
if np.isnan(float_fv):
128-
fv = "NaN"
129-
elif np.isinf(float_fv):
130-
fv = "Infinity" if float_fv > 0 else "-Infinity"
131-
elif isinstance(fv, numbers.Complex):
132-
fv = [_serialize_fill_value(fv.real), _serialize_fill_value(fv.imag)]
133-
return cast(JSON, fv)
134-
135112
def _json_convert(
136113
o: Any,
137114
) -> Any:
@@ -170,7 +147,7 @@ def _json_convert(
170147
raise TypeError
171148

172149
zarray_dict = self.to_dict()
173-
zarray_dict["fill_value"] = _serialize_fill_value(zarray_dict["fill_value"])
150+
zarray_dict["fill_value"] = _serialize_fill_value(self.fill_value, self.dtype)
174151
zattrs_dict = zarray_dict.pop("attributes", {})
175152
json_indent = config.get("json_indent")
176153
return {
@@ -185,7 +162,7 @@ def _json_convert(
185162
}
186163

187164
@classmethod
188-
def from_dict(cls, data: dict[str, JSON]) -> ArrayV2Metadata:
165+
def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
189166
# Make a copy to protect the original from modification.
190167
_data = data.copy()
191168
# Check that the zarr_format attribute is correct.
@@ -213,7 +190,7 @@ def from_dict(cls, data: dict[str, JSON]) -> ArrayV2Metadata:
213190

214191
_data = {k: v for k, v in _data.items() if k in expected}
215192

216-
return cls(**cast(Mapping[str, Any], _data))
193+
return cls(**_data)
217194

218195
def to_dict(self) -> dict[str, JSON]:
219196
zarray_dict = super().to_dict()
@@ -315,7 +292,7 @@ def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata:
315292
return data
316293

317294

318-
def parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
295+
def _parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
319296
"""Handle structured dtype/fill value pairs"""
320297
try:
321298
if isinstance(fill_value, list):
@@ -354,7 +331,10 @@ def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
354331
if fill_value is None or dtype.hasobject:
355332
pass
356333
elif dtype.fields is not None:
357-
fill_value = parse_structured_fill_value(fill_value, dtype)
334+
# the dtype is structured (has multiple fields), so the fill_value might be a
335+
# compound value (e.g., a tuple or dict) that needs field-wise processing.
336+
# We use parse_structured_fill_value to correctly convert each component.
337+
fill_value = _parse_structured_fill_value(fill_value, dtype)
358338
elif not isinstance(fill_value, np.void) and fill_value == 0:
359339
# this should be compatible across numpy versions for any array type, including
360340
# structured arrays
@@ -369,16 +349,9 @@ def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
369349
)
370350
elif dtype.kind in "SV" and isinstance(fill_value, str):
371351
fill_value = base64.standard_b64decode(fill_value)
372-
elif np.issubdtype(dtype, np.datetime64):
373-
if fill_value == "NaT":
374-
fill_value = np.array("NaT", dtype=dtype)[()]
375-
else:
376-
fill_value = np.array(fill_value, dtype=dtype)[()]
377352
elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2:
378353
complex_val = complex(float(fill_value[0]), float(fill_value[1]))
379354
fill_value = np.array(complex_val, dtype=dtype)[()]
380-
elif dtype.kind in "f" and fill_value in {"NaN", "Infinity", "-Infinity"}:
381-
fill_value = np.array(fill_value, dtype=dtype)[()]
382355
else:
383356
try:
384357
if isinstance(fill_value, bytes) and dtype.kind == "V":
@@ -394,6 +367,39 @@ def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
394367
return fill_value
395368

396369

370+
def _serialize_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> JSON:
371+
serialized: JSON
372+
373+
if fill_value is None:
374+
serialized = None
375+
elif dtype.kind in "SV":
376+
# There's a relationship between dtype and fill_value
377+
# that mypy isn't aware of. The fact that we have S or V dtype here
378+
# means we should have a bytes-type fill_value.
379+
serialized = base64.standard_b64encode(cast(bytes, fill_value)).decode("ascii")
380+
elif isinstance(fill_value, np.datetime64):
381+
serialized = np.datetime_as_string(fill_value)
382+
elif isinstance(fill_value, numbers.Integral):
383+
serialized = int(fill_value)
384+
elif isinstance(fill_value, numbers.Real):
385+
float_fv = float(fill_value)
386+
if np.isnan(float_fv):
387+
serialized = "NaN"
388+
elif np.isinf(float_fv):
389+
serialized = "Infinity" if float_fv > 0 else "-Infinity"
390+
else:
391+
serialized = float_fv
392+
elif isinstance(fill_value, numbers.Complex):
393+
serialized = [
394+
_serialize_fill_value(fill_value.real, dtype),
395+
_serialize_fill_value(fill_value.imag, dtype),
396+
]
397+
else:
398+
serialized = fill_value
399+
400+
return serialized
401+
402+
397403
def _default_fill_value(dtype: np.dtype[Any]) -> Any:
398404
"""
399405
Get the default fill value for a type.

tests/test_properties.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import dataclasses
22
import json
33
import numbers
4+
from typing import Any
45

56
import numpy as np
67
import pytest
@@ -12,7 +13,7 @@
1213

1314
import hypothesis.extra.numpy as npst
1415
import hypothesis.strategies as st
15-
from hypothesis import assume, given
16+
from hypothesis import assume, given, settings
1617

1718
from zarr.abc.store import Store
1819
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON
@@ -30,7 +31,7 @@
3031
)
3132

3233

33-
def deep_equal(a, b):
34+
def deep_equal(a: Any, b: Any) -> bool:
3435
"""Deep equality check with handling of special cases for array metadata classes"""
3536
if isinstance(a, (complex, np.complexfloating)) and isinstance(
3637
b, (complex, np.complexfloating)
@@ -100,6 +101,8 @@ def test_array_creates_implicit_groups(array):
100101
)
101102

102103

104+
# bump deadline from 200 to 300 to avoid (rare) intermittent timeouts
105+
@settings(deadline=300)
103106
@given(data=st.data())
104107
def test_basic_indexing(data: st.DataObject) -> None:
105108
zarray = data.draw(simple_arrays())
@@ -236,7 +239,7 @@ def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: in
236239
# assert_array_equal(nparray, zarray[:])
237240

238241

239-
def serialized_float_is_valid(serialized):
242+
def serialized_float_is_valid(serialized: numbers.Real | str) -> bool:
240243
"""
241244
Validate that the serialized representation of a float conforms to the spec.
242245

tests/test_v2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from zarr import config
1616
from zarr.abc.store import Store
1717
from zarr.core.buffer.core import default_buffer_prototype
18-
from zarr.core.metadata.v2 import parse_structured_fill_value
18+
from zarr.core.metadata.v2 import _parse_structured_fill_value
1919
from zarr.core.sync import sync
2020
from zarr.storage import MemoryStore, StorePath
2121

@@ -366,7 +366,7 @@ def test_structured_dtype_roundtrip(fill_value, tmp_path) -> None:
366366
def test_parse_structured_fill_value_valid(
367367
fill_value: Any, dtype: np.dtype[Any], expected_result: Any
368368
) -> None:
369-
result = parse_structured_fill_value(fill_value, dtype)
369+
result = _parse_structured_fill_value(fill_value, dtype)
370370
assert result.dtype == expected_result.dtype
371371
assert result == expected_result
372372
if isinstance(expected_result, np.void):
@@ -396,7 +396,7 @@ def test_parse_structured_fill_value_valid(
396396
)
397397
def test_parse_structured_fill_value_invalid(fill_value: Any, dtype: np.dtype[Any]) -> None:
398398
with pytest.raises(ValueError):
399-
parse_structured_fill_value(fill_value, dtype)
399+
_parse_structured_fill_value(fill_value, dtype)
400400

401401

402402
@pytest.mark.parametrize("fill_value", [None, b"x"], ids=["no_fill", "fill"])

0 commit comments

Comments
 (0)