Skip to content

Commit b726e0d

Browse files
committed
Add numpy encoder class for json.dumps and add test
1 parent 22ded1d commit b726e0d

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

zarr/tests/test_creation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,3 +540,8 @@ def test_create_read_only():
540540
assert z.read_only
541541
with pytest.raises(PermissionError):
542542
z[:] = 42
543+
544+
545+
def test_json_dumps_chunks_numpy_dtype():
546+
z = zeros((10,), chunks=(np.int64(2),))
547+
assert np.all(z[...] == 0)

zarr/tests/test_util.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import numpy as np
55
import pytest
66

7-
from zarr.util import (all_equal, flatten, guess_chunks, human_readable_size, info_html_report,
8-
info_text_report, is_total_slice, normalize_chunks,
7+
from zarr.util import (all_equal, flatten, guess_chunks, human_readable_size,
8+
info_html_report, info_text_report, is_total_slice,
9+
json_dumps, normalize_chunks,
910
normalize_dimension_separator,
1011
normalize_fill_value, normalize_order,
1112
normalize_resize_args, normalize_shape, retry_call,
@@ -238,3 +239,7 @@ def test_all_equal():
238239
# all_equal(None, *) always returns False
239240
assert not all_equal(None, np.array([None, None]))
240241
assert not all_equal(None, np.array([None, 10]))
242+
243+
244+
def test_json_dumps_numpy_dtype():
245+
assert json_dumps(np.int64(0)) == json_dumps(0)

zarr/util.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,23 @@ def flatten(arg: Iterable) -> Iterable:
3333
}
3434

3535

36+
class NumpyEncoder(json.JSONEncoder):
37+
38+
def default(self, o):
39+
# See https://github.com/python/cpython/blob/e0ec08dc49f8e6f94a735bc9946ef7a3fd898a44/Lib/json/encoder.py#L160
40+
if isinstance(o, np.integer):
41+
return int(o)
42+
if isinstance(o, np.floating):
43+
return float(o)
44+
if isinstance(o, np.ndarray):
45+
return o.tolist()
46+
return json.JSONEncoder.default(self, o)
47+
48+
3649
def json_dumps(o: Any) -> bytes:
3750
"""Write JSON in a consistent, human-readable way."""
3851
return json.dumps(o, indent=4, sort_keys=True, ensure_ascii=True,
39-
separators=(',', ': ')).encode('ascii')
52+
separators=(',', ': '), cls=NumpyEncoder).encode('ascii')
4053

4154

4255
def json_loads(s: str) -> Dict[str, Any]:

0 commit comments

Comments
 (0)