Skip to content

Commit bc588a7

Browse files
authored
Fix JSON encoding of complex fill values (#2432)
* Fix JSON encoding of complex fill values We were not replacing NaNs and Infs with the string versions. * Fix decoding of complex fill values * try excluding `math.inf` * Check complex numbers explicitly * Update src/zarr/core/metadata/v3.py
1 parent 6ce0526 commit bc588a7

File tree

2 files changed

+45
-4
lines changed

2 files changed

+45
-4
lines changed

src/zarr/core/metadata/v3.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@
4343

4444
DEFAULT_DTYPE = "float64"
4545

46+
# Keep in sync with _replace_special_floats
47+
SPECIAL_FLOATS_ENCODED = {
48+
"Infinity": np.inf,
49+
"-Infinity": -np.inf,
50+
"NaN": np.nan,
51+
}
52+
4653

4754
def parse_zarr_format(data: object) -> Literal[3]:
4855
if data == 3:
@@ -149,7 +156,7 @@ def default(self, o: object) -> Any:
149156
if isinstance(out, complex):
150157
# python complex types are not JSON serializable, so we use the
151158
# serialization defined in the zarr v3 spec
152-
return [out.real, out.imag]
159+
return _replace_special_floats([out.real, out.imag])
153160
elif np.isnan(out):
154161
return "NaN"
155162
elif np.isinf(out):
@@ -447,8 +454,11 @@ def parse_fill_value(
447454
if isinstance(fill_value, Sequence) and not isinstance(fill_value, str):
448455
if data_type in (DataType.complex64, DataType.complex128):
449456
if len(fill_value) == 2:
457+
decoded_fill_value = tuple(
458+
SPECIAL_FLOATS_ENCODED.get(value, value) for value in fill_value
459+
)
450460
# complex datatypes serialize to JSON arrays with two elements
451-
return np_dtype.type(complex(*fill_value))
461+
return np_dtype.type(complex(*decoded_fill_value))
452462
else:
453463
msg = (
454464
f"Got an invalid fill value for complex data type {data_type.value}."
@@ -475,12 +485,20 @@ def parse_fill_value(
475485
pass
476486
elif fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value):
477487
pass
478-
elif np_dtype.kind in "cf":
488+
elif np_dtype.kind == "f":
479489
# float comparison is not exact, especially when dtype <float64
480-
# so we us np.isclose for this comparison.
490+
# so we use np.isclose for this comparison.
481491
# this also allows us to compare nan fill_values
482492
if not np.isclose(fill_value, casted_value, equal_nan=True):
483493
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}")
494+
elif np_dtype.kind == "c":
495+
# confusingly np.isclose(np.inf, np.inf + 0j) is False on numpy<2, so compare real and imag parts
496+
# explicitly.
497+
if not (
498+
np.isclose(np.real(fill_value), np.real(casted_value), equal_nan=True)
499+
and np.isclose(np.imag(fill_value), np.imag(casted_value), equal_nan=True)
500+
):
501+
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}")
484502
else:
485503
if fill_value != casted_value:
486504
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}")

tests/test_array.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
import math
13
import pickle
24
from itertools import accumulate
35
from typing import Any, Literal
@@ -9,6 +11,7 @@
911
from zarr import Array, AsyncArray, Group
1012
from zarr.codecs import BytesCodec, VLenBytesCodec
1113
from zarr.core.array import chunks_initialized
14+
from zarr.core.buffer import default_buffer_prototype
1215
from zarr.core.buffer.cpu import NDBuffer
1316
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
1417
from zarr.core.group import AsyncGroup
@@ -624,3 +627,23 @@ def test_array_create_order(
624627
assert vals.flags.f_contiguous
625628
else:
626629
raise AssertionError
630+
631+
632+
@pytest.mark.parametrize(
633+
("fill_value", "expected"),
634+
[
635+
(np.nan * 1j, ["NaN", "NaN"]),
636+
(np.nan, ["NaN", 0.0]),
637+
(np.inf, ["Infinity", 0.0]),
638+
(np.inf * 1j, ["NaN", "Infinity"]),
639+
(-np.inf, ["-Infinity", 0.0]),
640+
(math.inf, ["Infinity", 0.0]),
641+
],
642+
)
643+
async def test_special_complex_fill_values_roundtrip(fill_value: Any, expected: list[Any]) -> None:
644+
store = MemoryStore({}, mode="w")
645+
Array.create(store=store, shape=(1,), dtype=np.complex64, fill_value=fill_value)
646+
content = await store.get("zarr.json", prototype=default_buffer_prototype())
647+
assert content is not None
648+
actual = json.loads(content.to_bytes())
649+
assert actual["fill_value"] == expected

0 commit comments

Comments
 (0)