Skip to content

Commit 1087178

Browse files
committed
Add v2, v3 specific dtypes
1 parent 6db8225 commit 1087178

File tree

4 files changed

+63
-16
lines changed

4 files changed

+63
-16
lines changed

src/zarr/core/buffer/core.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,9 @@ def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
469469
return False
470470
# use array_equal to obtain equal_nan=True functionality
471471
data, other = np.broadcast_arrays(self._data, other)
472-
result = np.array_equal(self._data, other, equal_nan=equal_nan)
472+
result = np.array_equal(
473+
self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "US" else False
474+
)
473475
return result
474476

475477
def fill(self, value: Any) -> None:

src/zarr/testing/strategies.py

+39-10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Literal
22

33
import hypothesis.extra.numpy as npst
44
import hypothesis.strategies as st
@@ -19,19 +19,34 @@
1919
)
2020

2121

22-
def dtypes() -> st.SearchStrategy[np.dtype]:
22+
def v3_dtypes() -> st.SearchStrategy[np.dtype]:
2323
return (
2424
npst.boolean_dtypes()
2525
| npst.integer_dtypes(endianness="=")
2626
| npst.unsigned_integer_dtypes(endianness="=")
2727
| npst.floating_dtypes(endianness="=")
2828
| npst.complex_number_dtypes(endianness="=")
29+
# | npst.byte_string_dtypes(endianness="=")
2930
# | npst.unicode_string_dtypes()
3031
# | npst.datetime64_dtypes()
3132
# | npst.timedelta64_dtypes()
3233
)
3334

3435

36+
def v2_dtypes() -> st.SearchStrategy[np.dtype]:
37+
return (
38+
npst.boolean_dtypes()
39+
| npst.integer_dtypes(endianness="=")
40+
| npst.unsigned_integer_dtypes(endianness="=")
41+
| npst.floating_dtypes(endianness="=")
42+
| npst.complex_number_dtypes(endianness="=")
43+
| npst.byte_string_dtypes(endianness="=")
44+
| npst.unicode_string_dtypes(endianness="=")
45+
| npst.datetime64_dtypes()
46+
# | npst.timedelta64_dtypes()
47+
)
48+
49+
3550
# From https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#node-names
3651
# 1. must not be the empty string ("")
3752
# 2. must not include the character "/"
@@ -46,18 +61,29 @@ def dtypes() -> st.SearchStrategy[np.dtype]:
4661
array_names = node_names
4762
attrs = st.none() | st.dictionaries(_attr_keys, _attr_values)
4863
paths = st.lists(node_names, min_size=1).map(lambda x: "/".join(x)) | st.just("/")
49-
np_arrays = npst.arrays(
50-
dtype=dtypes(),
51-
shape=npst.array_shapes(max_dims=4),
52-
)
5364
stores = st.builds(MemoryStore, st.just({}), mode=st.just("w"))
5465
compressors = st.sampled_from([None, "default"])
55-
zarr_formats = st.sampled_from([2, 3])
66+
zarr_formats: st.SearchStrategy[Literal[2, 3]] = st.sampled_from([2, 3])
67+
array_shapes = npst.array_shapes(max_dims=4)
68+
69+
70+
@st.composite # type: ignore[misc]
71+
def numpy_arrays(
72+
draw: st.DrawFn,
73+
*,
74+
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
75+
zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats,
76+
) -> Any:
77+
"""
78+
Generate numpy arrays that can be saved in the provided Zarr format.
79+
"""
80+
zarr_format = draw(zarr_formats)
81+
return draw(npst.arrays(dtype=v3_dtypes() if zarr_format == 3 else v2_dtypes(), shape=shapes))
5682

5783

5884
@st.composite # type: ignore[misc]
5985
def np_array_and_chunks(
60-
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = np_arrays
86+
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = numpy_arrays
6187
) -> tuple[np.ndarray, tuple[int]]: # type: ignore[type-arg]
6288
"""A hypothesis strategy to generate small sized random arrays.
6389
@@ -76,20 +102,23 @@ def np_array_and_chunks(
76102
def arrays(
77103
draw: st.DrawFn,
78104
*,
105+
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
79106
compressors: st.SearchStrategy = compressors,
80107
stores: st.SearchStrategy[StoreLike] = stores,
81-
arrays: st.SearchStrategy[np.ndarray] = np_arrays,
82108
paths: st.SearchStrategy[None | str] = paths,
83109
array_names: st.SearchStrategy = array_names,
110+
arrays: st.SearchStrategy | None = None,
84111
attrs: st.SearchStrategy = attrs,
85112
zarr_formats: st.SearchStrategy = zarr_formats,
86113
) -> Array:
87114
store = draw(stores)
88-
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
89115
path = draw(paths)
90116
name = draw(array_names)
91117
attributes = draw(attrs)
92118
zarr_format = draw(zarr_formats)
119+
if arrays is None:
120+
arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format))
121+
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
93122
# test that None works too.
94123
fill_value = draw(st.one_of([st.none(), npst.from_dtype(nparray.dtype)]))
95124
# compressor = draw(compressors)

tests/v3/test_properties.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
import hypothesis.extra.numpy as npst # noqa
88
import hypothesis.strategies as st # noqa
99
from hypothesis import given, settings # noqa
10-
from zarr.testing.strategies import arrays, np_arrays, basic_indices # noqa
10+
from zarr.testing.strategies import arrays, numpy_arrays, basic_indices, zarr_formats # noqa
1111

1212

13-
@given(st.data())
14-
def test_roundtrip(data: st.DataObject) -> None:
15-
nparray = data.draw(np_arrays)
16-
zarray = data.draw(arrays(arrays=st.just(nparray)))
13+
@given(data=st.data(), zarr_format=zarr_formats)
14+
def test_roundtrip(data: st.DataObject, zarr_format: int) -> None:
15+
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
16+
zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format)))
1717
assert_array_equal(nparray, zarray[:])
1818

1919

@@ -31,6 +31,7 @@ def test_basic_indexing(data: st.DataObject) -> None:
3131
assert_array_equal(nparray, zarray[:])
3232

3333

34+
@settings(report_multiple_bugs=False)
3435
@given(data=st.data())
3536
def test_vindex(data: st.DataObject) -> None:
3637
zarray = data.draw(arrays())

zarr/version.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# file generated by setuptools_scm
2+
# don't change, don't track in version control
3+
TYPE_CHECKING = False
4+
if TYPE_CHECKING:
5+
VERSION_TUPLE = tuple[int | str, ...]
6+
else:
7+
VERSION_TUPLE = object
8+
9+
version: str
10+
__version__: str
11+
__version_tuple__: VERSION_TUPLE
12+
version_tuple: VERSION_TUPLE
13+
14+
__version__ = version = "2.18.2"
15+
__version_tuple__ = version_tuple = (2, 18, 2)

0 commit comments

Comments
 (0)