Skip to content

Commit cea866f

Browse files
authored
Merge branch 'main' into zarr-python-use-sentinel-type
2 parents e443b56 + 96a62b5 commit cea866f

3 files changed

Lines changed: 123 additions & 0 deletions

File tree

changes/4054.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add Hypothesis property tests for block and mask indexing (`test_block_indexing`, `test_mask_indexing`), along with a `block_indices` strategy in `zarr.testing.strategies`. These extend the existing randomized indexing coverage (basic, orthogonal, and vectorized) to the block and mask selection methods.

src/zarr/testing/strategies.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,67 @@ def orthogonal_indices(
591591
return tuple(zindexer), tuple(np.broadcast_arrays(*npindexer))
592592

593593

594+
@st.composite
595+
def block_indices(
596+
draw: st.DrawFn, *, chunk_grid_shape: tuple[int, ...], chunks: tuple[int, ...]
597+
) -> tuple[tuple[int | slice, ...], tuple[slice, ...]]:
598+
"""
599+
Strategy for block-selection indexers over a *regular* chunk grid.
600+
601+
Block indexing is basic indexing applied to the block grid (the grid of
602+
chunks), so each axis is drawn with ``basic_indices`` over that axis's chunk
603+
count from ``chunk_grid_shape`` (e.g. ``Array.cdata_shape``), mirroring how
604+
``orthogonal_indices`` reuses ``basic_indices`` per axis. Block indexing only
605+
supports integers and step-1 slices whose start references an existing chunk,
606+
so strided slices and slices starting at the grid edge are filtered out. The
607+
array-space translation assumes a regular (uniform) chunk grid; an over-long
608+
stop into a smaller last chunk is left for numpy to clamp when the oracle is
609+
applied.
610+
611+
Returns
612+
-------
613+
block_indexer
614+
A per-axis tuple of ints / step-1 slices addressing whole chunks,
615+
suitable for ``Array.blocks`` / ``get_block_selection`` / ``set_block_selection``.
616+
array_indexer
617+
The equivalent array-space selection (a tuple of slices) for indexing
618+
the corresponding numpy array, used as the comparison oracle.
619+
"""
620+
621+
def supported(nchunks: int) -> Callable[[tuple[Any, ...]], bool]:
622+
# Block indexing only accepts step-1 slices whose start references an
623+
# existing chunk (a slice starting at nchunks raises, unlike numpy).
624+
def predicate(value: tuple[Any, ...]) -> bool:
625+
dim_sel = value[0]
626+
if isinstance(dim_sel, slice):
627+
if dim_sel.step not in (None, 1):
628+
return False
629+
start = dim_sel.start or 0
630+
return 0 <= (start + nchunks if start < 0 else start) < nchunks
631+
return True
632+
633+
return predicate
634+
635+
block_indexer: list[int | slice] = []
636+
array_indexer: list[slice] = []
637+
for chunk, nchunks in zip(chunks, chunk_grid_shape, strict=True):
638+
(dim_sel,) = draw(
639+
basic_indices(min_dims=1, shape=(nchunks,), allow_ellipsis=False)
640+
# normalize bare ints / slices to a 1-tuple, skip the empty tuple
641+
.map(lambda x: (x,) if not isinstance(x, tuple) else x)
642+
.filter(bool)
643+
.filter(supported(nchunks))
644+
)
645+
block_indexer.append(dim_sel)
646+
if isinstance(dim_sel, slice):
647+
start, stop, _ = dim_sel.indices(nchunks)
648+
array_indexer.append(slice(start * chunk, stop * chunk))
649+
else:
650+
block = dim_sel % nchunks
651+
array_indexer.append(slice(block * chunk, (block + 1) * chunk))
652+
return tuple(block_indexer), tuple(array_indexer)
653+
654+
594655
def key_ranges(
595656
keys: SearchStrategy[str] = node_names, max_size: int = sys.maxsize
596657
) -> SearchStrategy[list[tuple[str, RangeByteRequest]]]:

tests/test_properties.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
array_metadata,
2626
arrays,
2727
basic_indices,
28+
block_indices,
2829
complex_rectilinear_arrays,
30+
np_array_and_chunks,
2931
numpy_arrays,
3032
orthogonal_indices,
3133
rectilinear_arrays,
@@ -230,6 +232,65 @@ async def test_vindex(data: st.DataObject) -> None:
230232
# note: async vindex setitem not yet implemented
231233

232234

235+
@settings(deadline=None)
236+
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
237+
@given(data=st.data())
238+
def test_mask_indexing(data: st.DataObject) -> None:
239+
zarray = data.draw(st.one_of(simple_arrays(), rectilinear_arrays()))
240+
nparray = zarray[:]
241+
mask = data.draw(npst.arrays(dtype=np.bool_, shape=st.just(nparray.shape)))
242+
243+
expected = nparray[mask]
244+
245+
# sync get, via both the dedicated method and the vindex interface
246+
assert_array_equal(expected, zarray.get_mask_selection(mask))
247+
assert_array_equal(expected, zarray.vindex[mask])
248+
249+
# sync set, via both interfaces
250+
assume(zarray.shards is None) # GH2834
251+
new_data = data.draw(numpy_arrays(shapes=st.just(expected.shape), dtype=nparray.dtype))
252+
nparray[mask] = new_data
253+
zarray.set_mask_selection(mask, new_data)
254+
assert_array_equal(nparray, zarray[:])
255+
256+
zarray.vindex[mask] = new_data
257+
assert_array_equal(nparray, zarray[:])
258+
259+
260+
@settings(deadline=None)
261+
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
262+
@given(data=st.data())
263+
def test_block_indexing(data: st.DataObject) -> None:
264+
# Block indexing addresses whole chunks on a regular grid; the array-space
265+
# oracle in block_indices() assumes regular, unsharded chunks, so build the
266+
# array directly from a regular chunking rather than drawing one that might
267+
# be rectilinear or sharded.
268+
nparray, chunks = data.draw(
269+
np_array_and_chunks(arrays=numpy_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
270+
)
271+
store = data.draw(stores)
272+
zarray = zarr.create_array(store=store, shape=nparray.shape, chunks=chunks, dtype=nparray.dtype)
273+
zarray[...] = nparray
274+
275+
block_indexer, array_indexer = data.draw(
276+
block_indices(chunk_grid_shape=zarray.cdata_shape, chunks=chunks)
277+
)
278+
expected = nparray[array_indexer]
279+
280+
# sync get, via both the .blocks interface and the dedicated method
281+
assert_array_equal(expected, zarray.blocks[block_indexer])
282+
assert_array_equal(expected, zarray.get_block_selection(block_indexer))
283+
284+
# sync set, via both interfaces
285+
new_data = data.draw(numpy_arrays(shapes=st.just(expected.shape), dtype=nparray.dtype))
286+
nparray[array_indexer] = new_data
287+
zarray.blocks[block_indexer] = new_data
288+
assert_array_equal(nparray, zarray[:])
289+
290+
zarray.set_block_selection(block_indexer, new_data)
291+
assert_array_equal(nparray, zarray[:])
292+
293+
233294
@given(store=stores, meta=array_metadata()) # type: ignore[misc]
234295
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
235296
async def test_roundtrip_array_metadata_from_store(

0 commit comments

Comments
 (0)