diff --git a/docs/release.rst b/docs/release.rst index 13c2f20d2c..7a5cf51db7 100644 --- a/docs/release.rst +++ b/docs/release.rst @@ -14,6 +14,9 @@ Bug fixes value when empty chunks are read back in. By :user:`Vyas Ramasubramani `; :issue:`965`. +* Add number encoder for ``json.dumps`` to support numpy intergers in + ``chunks`` arguments. By :user:`Eric Prestat ` :issue:`697`. + .. _release_2.11.1: 2.11.1 diff --git a/zarr/tests/test_creation.py b/zarr/tests/test_creation.py index cfab4f79ec..ee99bc7c9f 100644 --- a/zarr/tests/test_creation.py +++ b/zarr/tests/test_creation.py @@ -714,3 +714,8 @@ def test_create_read_only(zarr_version): assert z.read_only with pytest.raises(PermissionError): z[:] = 42 + + +def test_json_dumps_chunks_numpy_dtype(): + z = zeros((10,), chunks=(np.int64(2),)) + assert np.all(z[...] == 0) diff --git a/zarr/tests/test_util.py b/zarr/tests/test_util.py index efe8e66341..e9e1786abe 100644 --- a/zarr/tests/test_util.py +++ b/zarr/tests/test_util.py @@ -4,8 +4,10 @@ import numpy as np import pytest -from zarr.util import (all_equal, flatten, guess_chunks, human_readable_size, info_html_report, - info_text_report, is_total_slice, normalize_chunks, +from zarr.core import Array +from zarr.util import (all_equal, flatten, guess_chunks, human_readable_size, + info_html_report, info_text_report, is_total_slice, + json_dumps, normalize_chunks, normalize_dimension_separator, normalize_fill_value, normalize_order, normalize_resize_args, normalize_shape, retry_call, @@ -238,3 +240,11 @@ def test_all_equal(): # all_equal(None, *) always returns False assert not all_equal(None, np.array([None, None])) assert not all_equal(None, np.array([None, 10])) + + +def test_json_dumps_numpy_dtype(): + assert json_dumps(np.int64(0)) == json_dumps(0) + assert json_dumps(np.float32(0)) == json_dumps(float(0)) + # Check that we raise the error of the superclass for unsupported object + with pytest.raises(TypeError): + json_dumps(Array) diff --git a/zarr/util.py b/zarr/util.py index 9f5f04f525..cc3bd50356 100644 --- a/zarr/util.py +++ b/zarr/util.py @@ -33,10 +33,22 @@ def flatten(arg: Iterable) -> Iterable: } +class NumberEncoder(json.JSONEncoder): + + def default(self, o): + # See json.JSONEncoder.default docstring for explanation + # This is necessary to encode numpy dtype + if isinstance(o, numbers.Integral): + return int(o) + if isinstance(o, numbers.Real): + return float(o) + return json.JSONEncoder.default(self, o) + + def json_dumps(o: Any) -> bytes: """Write JSON in a consistent, human-readable way.""" return json.dumps(o, indent=4, sort_keys=True, ensure_ascii=True, - separators=(',', ': ')).encode('ascii') + separators=(',', ': '), cls=NumberEncoder).encode('ascii') def json_loads(s: str) -> Dict[str, Any]: