Skip to content

Commit c6a27a2

Browse files
committed
test: add property tests for block and mask indexing
1 parent b9d3964 commit c6a27a2

3 files changed

Lines changed: 105 additions & 0 deletions

File tree

changes/4054.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add Hypothesis property tests for block and mask indexing (`test_block_indexing`, `test_mask_indexing`), along with a `block_indices` strategy in `zarr.testing.strategies`. These extend the existing randomized indexing coverage (basic, orthogonal, and vectorized) to the block and mask selection methods.

src/zarr/testing/strategies.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,51 @@ def orthogonal_indices(
591591
return tuple(zindexer), tuple(np.broadcast_arrays(*npindexer))
592592

593593

594+
@st.composite
def block_indices(
    draw: st.DrawFn, *, shape: tuple[int, ...], chunks: tuple[int, ...]
) -> tuple[tuple[int | slice, ...], tuple[slice, ...]]:
    """
    Draw a block-grid selection together with its element-space equivalent.

    A block selection picks whole chunks rather than individual elements. Per
    dimension this strategy emits either one block index (which may be
    negative, counting from the last block) or a step-1 slice of consecutive
    blocks; strided block slices are rejected by block indexing, so no step,
    newaxis, or ellipsis is ever produced. The element-space translation
    multiplies block positions by the chunk length and clips to the array
    extent, so it is only valid for a regular (uniform) chunk grid where just
    the final chunk of a dimension may be short. Each dimension must contain
    at least one chunk (``size >= 1``).

    Parameters
    ----------
    draw
        Hypothesis draw function injected by ``st.composite``.
    shape
        Shape of the array being indexed.
    chunks
        Regular chunk shape; must have the same length as ``shape``.

    Returns
    -------
    block_indexer
        Tuple of ints / step-1 slices over the block grid, suitable for
        ``Array.blocks`` / ``Array.get_block_selection`` /
        ``set_block_selection``.
    array_indexer
        Tuple of slices selecting the same region of the equivalent numpy
        array; serves as the comparison oracle.
    """
    # number of blocks along each dimension (ceil division without floats)
    blocks_per_dim = [-(-extent // step) for extent, step in zip(shape, chunks, strict=True)]

    block_sel: list[int | slice] = []
    array_sel: list[slice] = []
    for extent, step, n_blocks in zip(shape, chunks, blocks_per_dim, strict=True):
        pick_single = draw(st.booleans())
        if not pick_single:
            # A (possibly empty) run of adjacent blocks. The first block must
            # exist: block indexing rejects a slice starting at n_blocks,
            # whereas numpy would treat arr[len:len] as an empty selection.
            first = draw(st.integers(min_value=0, max_value=n_blocks - 1))
            last = draw(st.integers(min_value=first, max_value=n_blocks))
            block_sel.append(slice(first, last))
            array_sel.append(slice(first * step, min(last * step, extent)))
            continue

        # Exactly one block, occasionally addressed relative to the end.
        idx = draw(st.integers(min_value=-n_blocks, max_value=n_blocks - 1))
        block_sel.append(idx)
        # normalise a negative index, then clip the (possibly short) last chunk
        offset = (idx % n_blocks) * step
        array_sel.append(slice(offset, min(offset + step, extent)))

    return tuple(block_sel), tuple(array_sel)
637+
638+
594639
def key_ranges(
595640
keys: SearchStrategy[str] = node_names, max_size: int = sys.maxsize
596641
) -> SearchStrategy[list[tuple[str, RangeByteRequest]]]:

tests/test_properties.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
array_metadata,
2626
arrays,
2727
basic_indices,
28+
block_indices,
2829
complex_rectilinear_arrays,
30+
np_array_and_chunks,
2931
numpy_arrays,
3032
orthogonal_indices,
3133
rectilinear_arrays,
@@ -230,6 +232,63 @@ async def test_vindex(data: st.DataObject) -> None:
230232
# note: async vindex setitem not yet implemented
231233

232234

235+
@settings(deadline=None)
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@given(data=st.data())
def test_mask_indexing(data: st.DataObject) -> None:
    """
    Property test: boolean-mask selection on a zarr array matches numpy.

    Reads are checked through both ``get_mask_selection`` and ``vindex``.
    Writes are checked through both ``set_mask_selection`` and ``vindex``,
    each with its own freshly drawn payload so that neither write can pass
    merely because the other one already stored the expected values.
    """
    zarray = data.draw(st.one_of(simple_arrays(), rectilinear_arrays()))
    nparray = zarray[:]
    mask = data.draw(npst.arrays(dtype=np.bool_, shape=st.just(nparray.shape)))

    expected = nparray[mask]

    # sync get, via both the dedicated method and the vindex interface
    assert_array_equal(expected, zarray.get_mask_selection(mask))
    assert_array_equal(expected, zarray.vindex[mask])

    # sync set, via both interfaces
    assume(zarray.shards is None)  # GH2834
    new_data = data.draw(numpy_arrays(shapes=st.just(expected.shape), dtype=nparray.dtype))
    nparray[mask] = new_data
    zarray.set_mask_selection(mask, new_data)
    assert_array_equal(nparray, zarray[:])

    # Draw a second payload: re-writing ``new_data`` would leave the array
    # unchanged, so a no-op ``vindex`` assignment would go undetected.
    newer_data = data.draw(numpy_arrays(shapes=st.just(expected.shape), dtype=nparray.dtype))
    nparray[mask] = newer_data
    zarray.vindex[mask] = newer_data
    assert_array_equal(nparray, zarray[:])
258+
259+
260+
@settings(deadline=None)
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@given(data=st.data())
def test_block_indexing(data: st.DataObject) -> None:
    """
    Property test: block selection on a zarr array matches numpy slicing.

    Reads are checked through both ``.blocks`` and ``get_block_selection``.
    Writes are checked through both ``.blocks`` and ``set_block_selection``,
    each with its own freshly drawn payload so that neither write can pass
    merely because the other one already stored the expected values.
    """
    # Block indexing addresses whole chunks on a regular grid; the array-space
    # oracle in block_indices() assumes regular, unsharded chunks, so build the
    # array directly from a regular chunking rather than drawing one that might
    # be rectilinear or sharded.
    nparray, chunks = data.draw(
        np_array_and_chunks(arrays=numpy_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
    )
    store = data.draw(stores)
    zarray = zarr.create_array(store=store, shape=nparray.shape, chunks=chunks, dtype=nparray.dtype)
    zarray[...] = nparray

    block_indexer, array_indexer = data.draw(block_indices(shape=nparray.shape, chunks=chunks))
    expected = nparray[array_indexer]

    # sync get, via both the .blocks interface and the dedicated method
    assert_array_equal(expected, zarray.blocks[block_indexer])
    assert_array_equal(expected, zarray.get_block_selection(block_indexer))

    # sync set, via both interfaces
    new_data = data.draw(numpy_arrays(shapes=st.just(expected.shape), dtype=nparray.dtype))
    nparray[array_indexer] = new_data
    zarray.blocks[block_indexer] = new_data
    assert_array_equal(nparray, zarray[:])

    # Draw a second payload: re-writing ``new_data`` would leave the array
    # unchanged, so a no-op ``set_block_selection`` would go undetected.
    newer_data = data.draw(numpy_arrays(shapes=st.just(expected.shape), dtype=nparray.dtype))
    nparray[array_indexer] = newer_data
    zarray.set_block_selection(block_indexer, newer_data)
    assert_array_equal(nparray, zarray[:])
290+
291+
233292
@given(store=stores, meta=array_metadata()) # type: ignore[misc]
234293
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
235294
async def test_roundtrip_array_metadata_from_store(

0 commit comments

Comments
 (0)