From ebbfbe068ee0d1d8428e505df1c84a0be9570b48 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sat, 3 Aug 2024 11:36:13 +0200 Subject: [PATCH 01/22] implement store.list_prefix and store._set_dict --- src/zarr/abc/store.py | 10 +++- src/zarr/store/local.py | 4 -- src/zarr/store/memory.py | 2 +- src/zarr/store/remote.py | 9 +++- src/zarr/sync.py | 17 ++++++ src/zarr/testing/store.py | 85 ++++++++++++++++++++---------- tests/v3/test_store/test_remote.py | 36 ++++++------- 7 files changed, 107 insertions(+), 56 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 449816209b..95d12943b9 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Mapping from typing import Any, NamedTuple, Protocol, runtime_checkable from typing_extensions import Self @@ -221,6 +221,14 @@ def close(self) -> None: self._is_open = False pass + async def _set_dict(self, dict: Mapping[str, Buffer]) -> None: + """ + Insert objects into storage as defined by a prefix: value mapping. 
+ """ + for key, value in dict.items(): + await self.set(key, value) + return None + @runtime_checkable class ByteGetter(Protocol): diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index 25fd9fc13a..cc6ba38f21 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -193,10 +193,6 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: ------- AsyncGenerator[str, None] """ - for p in (self.root / prefix).rglob("*"): - if p.is_file(): - yield str(p) - to_strip = str(self.root) + "/" for p in (self.root / prefix).rglob("*"): if p.is_file(): diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index dd3e52e703..c3a61e8e51 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -101,7 +101,7 @@ async def list(self) -> AsyncGenerator[str, None]: async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: for key in self._store_dict: if key.startswith(prefix): - yield key + yield key.removeprefix(prefix) async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: if prefix.endswith("/"): diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index c742d9e567..87b8fe6573 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -205,5 +205,10 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: yield onefile async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: - for onefile in await self._fs._ls(prefix, detail=False): - yield onefile + if prefix == "": + find_str = "/".join([self.path, prefix]) + else: + find_str = "/".join([self.path, prefix]) + + for onefile in await self._fs._find(find_str): + yield onefile.removeprefix(find_str) diff --git a/src/zarr/sync.py b/src/zarr/sync.py index 8af14f602e..446ffd43e2 100644 --- a/src/zarr/sync.py +++ b/src/zarr/sync.py @@ -114,6 +114,23 @@ def _get_loop() -> asyncio.AbstractEventLoop: return loop[0] +async def _collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]: + """ + 
Collect an entire async iterator into a tuple + """ + result = [] + async for x in data: + result.append(x) + return tuple(result) + + +def collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]: + """ + Synchronously collect an entire async iterator into a tuple. + """ + return sync(_collect_aiterator(data)) + + class SyncMixin: def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T: # TODO: refactor this to to take *args and **kwargs and pass those to the method diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 4fdf497a68..ba37dda625 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -5,6 +5,7 @@ from zarr.abc.store import AccessMode, Store from zarr.buffer import Buffer, default_buffer_prototype from zarr.store.utils import _normalize_interval_index +from zarr.sync import _collect_aiterator from zarr.testing.utils import assert_bytes_equal S = TypeVar("S", bound=Store) @@ -103,6 +104,18 @@ async def test_set(self, store: S, key: str, data: bytes) -> None: observed = self.get(store, key) assert_bytes_equal(observed, data_buf) + async def test_set_dict(self, store: S) -> None: + """ + Test that a dict of key : value pairs can be inserted into the store via the + `_set_dict` method. 
+ """ + keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"] + data_buf = [Buffer.from_bytes(k.encode()) for k in keys] + store_dict = dict(zip(keys, data_buf, strict=True)) + await store._set_dict(store_dict) + for k, v in store_dict.items(): + assert self.get(store, k).to_bytes() == v.to_bytes() + @pytest.mark.parametrize( "key_ranges", ( @@ -165,37 +178,55 @@ async def test_clear(self, store: S) -> None: assert await store.empty() async def test_list(self, store: S) -> None: - assert [k async for k in store.list()] == [] - await store.set("foo/zarr.json", Buffer.from_bytes(b"bar")) - keys = [k async for k in store.list()] - assert keys == ["foo/zarr.json"], keys - - expected = ["foo/zarr.json"] - for i in range(10): - key = f"foo/c/{i}" - expected.append(key) - await store.set( - f"foo/c/{i}", Buffer.from_bytes(i.to_bytes(length=3, byteorder="little")) - ) + assert await _collect_aiterator(store.list()) == () + prefix = "foo" + data = Buffer.from_bytes(b"") + store_dict = { + prefix + "/zarr.json": data, + **{prefix + f"/c/{idx}": data for idx in range(10)}, + } + await store._set_dict(store_dict) + expected_sorted = sorted(store_dict.keys()) + observed = await _collect_aiterator(store.list()) + observed_sorted = sorted(observed) + assert observed_sorted == expected_sorted - @pytest.mark.xfail async def test_list_prefix(self, store: S) -> None: - # TODO: we currently don't use list_prefix anywhere - raise NotImplementedError + """ + Test that the `list_prefix` method works as intended. Given a prefix, it should return + all the keys in storage that start with this prefix. Keys should be returned with the shared + prefix removed. + """ + prefixes = ("", "a/", "a/b/", "a/b/c/") + data = Buffer.from_bytes(b"") + fname = "zarr.json" + store_dict = {p + fname: data for p in prefixes} + await store._set_dict(store_dict) + for p in prefixes: + observed = tuple(sorted(await _collect_aiterator(store.list_prefix(p)))) + expected: tuple[str, ...] 
= () + for k in store_dict.keys(): + if k.startswith(p): + expected += (k.removeprefix(p),) + expected = tuple(sorted(expected)) + assert observed == expected async def test_list_dir(self, store: S) -> None: - out = [k async for k in store.list_dir("")] - assert out == [] - assert [k async for k in store.list_dir("foo")] == [] - await store.set("foo/zarr.json", Buffer.from_bytes(b"bar")) - await store.set("foo/c/1", Buffer.from_bytes(b"\x01")) + root = "foo" + store_dict = { + root + "/zarr.json": Buffer.from_bytes(b"bar"), + root + "/c/1": Buffer.from_bytes(b"\x01"), + } + + assert await _collect_aiterator(store.list_dir("")) == () + assert await _collect_aiterator(store.list_dir(root)) == () + + await store._set_dict(store_dict) - keys_expected = ["zarr.json", "c"] - keys_observed = [k async for k in store.list_dir("foo")] + keys_observed = await _collect_aiterator(store.list_dir(root)) + keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()} - assert len(keys_observed) == len(keys_expected), keys_observed - assert set(keys_observed) == set(keys_expected), keys_observed + assert sorted(keys_observed) == sorted(keys_expected) - keys_observed = [k async for k in store.list_dir("foo/")] - assert len(keys_expected) == len(keys_observed), keys_observed - assert set(keys_observed) == set(keys_expected), keys_observed + keys_observed = await _collect_aiterator(store.list_dir(root + "/")) + assert sorted(keys_expected) == sorted(keys_observed) diff --git a/tests/v3/test_store/test_remote.py b/tests/v3/test_store/test_remote.py index be9fa5ef67..14a181d7b6 100644 --- a/tests/v3/test_store/test_remote.py +++ b/tests/v3/test_store/test_remote.py @@ -1,12 +1,18 @@ +from __future__ import annotations + import os +from collections.abc import Generator import fsspec import pytest +from botocore.client import BaseClient +from botocore.session import Session +from s3fs import S3FileSystem from upath import UPath from zarr.buffer import Buffer, 
default_buffer_prototype from zarr.store import RemoteStore -from zarr.sync import sync +from zarr.sync import _collect_aiterator, sync from zarr.testing.store import StoreTests s3fs = pytest.importorskip("s3fs") @@ -22,7 +28,7 @@ @pytest.fixture(scope="module") -def s3_base(): +def s3_base() -> Generator[None, None, None]: # writable local S3 system # This fixture is module-scoped, meaning that we can reuse the MotoServer across all tests @@ -37,16 +43,14 @@ def s3_base(): server.stop() -def get_boto3_client(): - from botocore.session import Session - +def get_boto3_client() -> BaseClient: # NB: we use the sync botocore client for setup session = Session() return session.create_client("s3", endpoint_url=endpoint_url) @pytest.fixture(autouse=True, scope="function") -def s3(s3_base): +def s3(s3_base: Generator[None, None, None]) -> Generator[S3FileSystem, None, None]: """ Quoting Martin Durant: pytest-asyncio creates a new event loop for each async test. @@ -71,21 +75,11 @@ def s3(s3_base): sync(session.close()) -# ### end from s3fs ### # - - -async def alist(it): - out = [] - async for a in it: - out.append(a) - return out - - -async def test_basic(): +async def test_basic() -> None: store = await RemoteStore.open( f"s3://{test_bucket_name}", mode="w", endpoint_url=endpoint_url, anon=False ) - assert not await alist(store.list()) + assert await _collect_aiterator(store.list()) == () assert not await store.exists("foo") data = b"hello" await store.set("foo", Buffer.from_bytes(data)) @@ -101,7 +95,7 @@ class TestRemoteStoreS3(StoreTests[RemoteStore]): store_cls = RemoteStore @pytest.fixture(scope="function", params=("use_upath", "use_str")) - def store_kwargs(self, request) -> dict[str, str | bool]: + def store_kwargs(self, request: pytest.FixtureRequest) -> dict[str, str | bool | UPath]: # type: ignore url = f"s3://{test_bucket_name}" anon = False mode = "r+" @@ -113,8 +107,8 @@ def store_kwargs(self, request) -> dict[str, str | bool]: raise AssertionError 
@pytest.fixture(scope="function") - def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore: - url = store_kwargs["url"] + async def store(self, store_kwargs: dict[str, str | bool | UPath]) -> RemoteStore: + url: str | UPath = store_kwargs["url"] mode = store_kwargs["mode"] if isinstance(url, UPath): out = self.store_cls(url=url, mode=mode) From da6083e6761ff2cc786a2651bdf5722f65f6636f Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sat, 3 Aug 2024 11:47:00 +0200 Subject: [PATCH 02/22] simplify string handling --- src/zarr/store/local.py | 2 +- src/zarr/store/remote.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index cc6ba38f21..fe18213435 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -196,7 +196,7 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: to_strip = str(self.root) + "/" for p in (self.root / prefix).rglob("*"): if p.is_file(): - yield str(p).replace(to_strip, "") + yield str(p).removeprefix(to_strip) async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: """ diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 87b8fe6573..0b9e3bb8ce 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -205,10 +205,6 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: yield onefile async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: - if prefix == "": - find_str = "/".join([self.path, prefix]) - else: - find_str = "/".join([self.path, prefix]) - + find_str = "/".join([self.path, prefix]) for onefile in await self._fs._find(find_str): yield onefile.removeprefix(find_str) From dc5fe4740d771afcb1abb87e92b2f6af17c9d01b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 2 Aug 2024 16:01:33 +0200 Subject: [PATCH 03/22] add nchunks_initialized, and necessary additions for it --- pyproject.toml | 1 + src/zarr/array.py | 140 
++++++++++++++++++++++++++++++++++++-- src/zarr/indexing.py | 28 +++++++- src/zarr/metadata.py | 10 +++ src/zarr/store/remote.py | 6 +- src/zarr/testing/store.py | 49 +++++++++---- tests/v3/conftest.py | 9 +++ tests/v3/test_array.py | 34 +++++++-- 8 files changed, 246 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f1be6725b6..b43f1f74a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,4 +220,5 @@ filterwarnings = [ "error:::zarr.*", "ignore:PY_SSIZE_T_CLEAN will be required.*:DeprecationWarning", "ignore:The loop argument is deprecated since Python 3.8.*:DeprecationWarning", + "ignore:.*is transitional and will be removed.*:DeprecationWarning", ] diff --git a/src/zarr/array.py b/src/zarr/array.py index e41118805e..1d1600da2b 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -10,12 +10,13 @@ # Questions to consider: # 1. Was splitting the array into two classes really necessary? from asyncio import gather -from collections.abc import Iterable +from collections.abc import Iterable, Iterator from dataclasses import dataclass, field, replace from typing import Any, Literal, cast import numpy as np import numpy.typing as npt +from typing_extensions import deprecated from zarr.abc.codec import Codec, CodecPipeline from zarr.abc.store import set_or_delete @@ -52,11 +53,13 @@ OrthogonalSelection, Selection, VIndex, + ceildiv, check_fields, check_no_multi_fields, is_pure_fancy_indexing, is_pure_orthogonal_indexing, is_scalar, + iter_grid, pop_fields, ) from zarr.metadata import ArrayMetadata, ArrayV2Metadata, ArrayV3Metadata @@ -65,7 +68,7 @@ from zarr.store.core import ( ensure_no_existing_node, ) -from zarr.sync import sync +from zarr.sync import collect_aiterator, sync def parse_array_metadata(data: Any) -> ArrayV2Metadata | ArrayV3Metadata: @@ -393,10 +396,12 @@ def shape(self) -> ChunkCoords: def chunks(self) -> ChunkCoords: if isinstance(self.metadata.chunk_grid, RegularChunkGrid): return 
self.metadata.chunk_grid.chunk_shape - else: - raise ValueError( - f"chunk attribute is only available for RegularChunkGrid, this array has a {self.metadata.chunk_grid}" - ) + + msg = ( + f"The `chunks` attribute is only defined for arrays using `RegularChunkGrid`." + f"This array has a {self.metadata.chunk_grid} instead." + ) + raise NotImplementedError(msg) @property def size(self) -> int: @@ -437,6 +442,59 @@ def basename(self) -> str | None: return self.name.split("/")[-1] return None + @property + @deprecated( + "cdata_shape is transitional and will be removed in an early zarr-python v3 release." + ) + def cdata_shape(self) -> ChunkCoords: + """ + The shape of the chunk grid for this array. + """ + return tuple(ceildiv(s, c) for s, c in zip(self.shape, self.chunks, strict=False)) + + @property + @deprecated("nchunks is transitional and will be removed in an early zarr-python v3 release.") + def nchunks(self) -> int: + """ + The number of chunks in the stored representation of this array. + """ + return product(self.cdata_shape) + + @property + def _iter_chunks(self) -> Iterator[ChunkCoords]: + """ + Produce an iterator over the coordinates of each chunk, in chunk grid space. + """ + return iter_grid(self.cdata_shape) + + @property + def _iter_chunk_keys(self) -> Iterator[str]: + """ + Return an iterator over the keys of each chunk. + """ + for k in self._iter_chunks: + yield self.metadata.encode_chunk_key(k) + + @property + def _iter_chunk_regions(self) -> Iterator[tuple[slice, ...]]: + """ + Iterate over the regions spanned by each chunk. + """ + for cgrid_position in self._iter_chunks: + out: tuple[slice, ...] = () + for c_pos, c_shape in zip(cgrid_position, self.chunks, strict=False): + start = c_pos * c_shape + stop = start + c_shape + out += (slice(start, stop, 1),) + yield out + + @property + def nbytes(self) -> int: + """ + The number of bytes that can be stored in this array. 
+ """ + return self.nchunks * self.dtype.itemsize + async def _get_selection( self, indexer: Indexer, @@ -735,6 +793,52 @@ def read_only(self) -> bool: def fill_value(self) -> Any: return self.metadata.fill_value + @property + @deprecated( + "cdata_shape is transitional and will be removed in an early zarr-python v3 release." + ) + def cdata_shape(self) -> ChunkCoords: + """ + The shape of the chunk grid for this array. + """ + return tuple(ceildiv(s, c) for s, c in zip(self.shape, self.chunks, strict=False)) + + @property + @deprecated("nchunks is transitional and will be removed in an early zarr-python v3 release.") + def nchunks(self) -> int: + """ + The number of chunks in the stored representation of this array. + """ + return self._async_array.nchunks + + @property + def _iter_chunks(self) -> Iterator[ChunkCoords]: + """ + Produce an iterator over the coordinates of each chunk, in chunk grid space. + """ + yield from self._async_array._iter_chunks + + @property + def nbytes(self) -> int: + """ + The number of bytes that can be stored in this array. + """ + return self._async_array.nbytes + + @property + def _iter_chunk_keys(self) -> Iterator[str]: + """ + Return an iterator over the keys of each chunk. + """ + yield from self._async_array._iter_chunk_keys + + @property + def _iter_chunk_regions(self) -> Iterator[tuple[slice, ...]]: + """ + Iterate over the regions spanned by each chunk. + """ + yield from self._async_array._iter_chunk_regions + def __array__( self, dtype: npt.DTypeLike | None = None, copy: bool | None = None ) -> NDArrayLike: @@ -2056,3 +2160,27 @@ def info(self) -> None: return sync( self._async_array.info(), ) + + +@deprecated( + "nchunks_initialized is transitional and will be removed in an early zarr-python v3 release." +) +def nchunks_initialized(array: Array) -> int: + return len(chunks_initialized(array)) + + +def chunks_initialized(array: Array) -> tuple[str, ...]: + """ + Return the keys of all the chunks that exist in storage. 
+ """ + # todo: make this compose with the underlying async iterator + store_contents = list( + collect_aiterator(array.store_path.store.list_prefix(prefix=array.store_path.path)) + ) + out: list[str] = [] + + for chunk_key in array._iter_chunk_keys: + if chunk_key in store_contents: + out.append(chunk_key) + + return tuple(out) diff --git a/src/zarr/indexing.py b/src/zarr/indexing.py index 6987f69c11..91ab571b36 100644 --- a/src/zarr/indexing.py +++ b/src/zarr/indexing.py @@ -4,14 +4,13 @@ import math import numbers import operator -from collections.abc import Iterator, Sequence +from collections.abc import Iterable, Iterator, Sequence from dataclasses import dataclass from enum import Enum from functools import reduce from types import EllipsisType from typing import ( TYPE_CHECKING, - Any, NamedTuple, Protocol, TypeGuard, @@ -27,6 +26,8 @@ from zarr.common import ChunkCoords, product if TYPE_CHECKING: + from typing import Any + from zarr.array import Array from zarr.chunk_grids import ChunkGrid @@ -86,6 +87,29 @@ def ceildiv(a: float, b: float) -> int: return math.ceil(a / b) +def iter_grid(shape: Iterable[int]) -> Iterator[ChunkCoords]: + """ + Iterate over the elements of grid. + + Takes a grid shape expressed as an iterable of ints and + yields tuples bounded by that grid shape in lexicographic order. + + Examples + -------- + >>> tuple(iter_grid((1,))) + ((0,),) + + >>> tuple(iter_grid((2,3))) + ((0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)) + + Parameters + ---------- + shape: Iterable[int] + The shape of the grid to iterate over. 
+ """ + yield from itertools.product(*(map(range, shape))) + + def is_integer(x: Any) -> TypeGuard[int]: """True if x is an integer (both pure Python or NumPy).""" return isinstance(x, numbers.Integral) and not is_bool(x) diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index e801a6f966..0d1a9788a1 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -141,6 +141,10 @@ def get_chunk_spec( def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str: pass + @abstractmethod + def decode_chunk_key(self, key: str) -> ChunkCoords: + pass + @abstractmethod def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: pass @@ -252,6 +256,9 @@ def get_chunk_spec( def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str: return self.chunk_key_encoding.encode_chunk_key(chunk_coords) + def decode_chunk_key(self, key: str) -> ChunkCoords: + return self.chunk_key_encoding.decode_chunk_key(key) + def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: def _json_convert(o: Any) -> Any: if isinstance(o, np.dtype): @@ -445,6 +452,9 @@ def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str: chunk_identifier = self.dimension_separator.join(map(str, chunk_coords)) return "0" if chunk_identifier == "" else chunk_identifier + def decode_chunk_key(self, key: str) -> ChunkCoords: + return tuple(map(int, key.split(self.dimension_separator))) + def update_shape(self, shape: ChunkCoords) -> Self: return replace(self, shape=shape) diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 0b9e3bb8ce..0fd013cb4b 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -205,6 +205,6 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: yield onefile async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: - find_str = "/".join([self.path, prefix]) - for onefile in await self._fs._find(find_str): - yield onefile.removeprefix(find_str) + glob_str = 
f"{self.path}/{prefix.rstrip('/')}**" + for onefile in await self._fs._glob(glob_str): + yield onefile diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index ba37dda625..8662abaeb7 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -6,6 +6,7 @@ from zarr.buffer import Buffer, default_buffer_prototype from zarr.store.utils import _normalize_interval_index from zarr.sync import _collect_aiterator +from zarr.sync import _collect_aiterator from zarr.testing.utils import assert_bytes_equal S = TypeVar("S", bound=Store) @@ -104,17 +105,17 @@ async def test_set(self, store: S, key: str, data: bytes) -> None: observed = self.get(store, key) assert_bytes_equal(observed, data_buf) - async def test_set_dict(self, store: S) -> None: + async def set_set_dict(self, store: S) -> None: """ Test that a dict of key : value pairs can be inserted into the store via the `_set_dict` method. """ keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"] data_buf = [Buffer.from_bytes(k.encode()) for k in keys] - store_dict = dict(zip(keys, data_buf, strict=True)) + store_dict = dict(zip(keys, data_buf, strict=False)) await store._set_dict(store_dict) for k, v in store_dict.items(): - assert self.get(store, k).to_bytes() == v.to_bytes() + assert self.get(store, k) == v @pytest.mark.parametrize( "key_ranges", @@ -191,12 +192,20 @@ async def test_list(self, store: S) -> None: observed_sorted = sorted(observed) assert observed_sorted == expected_sorted + assert await _collect_aiterator(store.list()) == () + prefix = "foo" + data = Buffer.from_bytes(b"") + store_dict = { + prefix + "/zarr.json": data, + **{prefix + f"/c/{idx}": data for idx in range(10)}, + } + await store._set_dict(store_dict) + expected_sorted = sorted(store_dict.keys()) + observed = await _collect_aiterator(store.list()) + observed_sorted = sorted(observed) + assert observed_sorted == expected_sorted + async def test_list_prefix(self, store: S) -> None: - """ - Test that the `list_prefix` 
method works as intended. Given a prefix, it should return - all the keys in storage that start with this prefix. Keys should be returned with the shared - prefix removed. - """ prefixes = ("", "a/", "a/b/", "a/b/c/") data = Buffer.from_bytes(b"") fname = "zarr.json" @@ -204,11 +213,7 @@ async def test_list_prefix(self, store: S) -> None: await store._set_dict(store_dict) for p in prefixes: observed = tuple(sorted(await _collect_aiterator(store.list_prefix(p)))) - expected: tuple[str, ...] = () - for k in store_dict.keys(): - if k.startswith(p): - expected += (k.removeprefix(p),) - expected = tuple(sorted(expected)) + expected = tuple(sorted(filter(lambda v: v.startswith(p), store_dict.keys()))) assert observed == expected async def test_list_dir(self, store: S) -> None: @@ -230,3 +235,21 @@ async def test_list_dir(self, store: S) -> None: keys_observed = await _collect_aiterator(store.list_dir(root + "/")) assert sorted(keys_expected) == sorted(keys_observed) + root = "foo" + store_dict = { + root + "/zarr.json": Buffer.from_bytes(b"bar"), + root + "/c/1": Buffer.from_bytes(b"\x01"), + } + + assert await _collect_aiterator(store.list_dir("")) == () + assert await _collect_aiterator(store.list_dir(root)) == () + + await store._set_dict(store_dict) + + keys_observed = await _collect_aiterator(store.list_dir(root)) + keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()} + + assert sorted(keys_observed) == sorted(keys_expected) + + keys_observed = await _collect_aiterator(store.list_dir(root + "/")) + assert sorted(keys_expected) == sorted(keys_observed) diff --git a/tests/v3/conftest.py b/tests/v3/conftest.py index 267fcc85bd..ec5e1b22ec 100644 --- a/tests/v3/conftest.py +++ b/tests/v3/conftest.py @@ -119,3 +119,12 @@ def array_fixture(request: pytest.FixtureRequest) -> np.ndarray: .reshape(array_request.shape, order=array_request.order) .astype(array_request.dtype) ) + + +@pytest.fixture(params=[2, 3]) +def zarr_format(request: 
pytest.FixtureRequest) -> ZarrFormat: + if request.param == 2: + return 2 + if request.param == 3: + return 3 + raise ValueError("Invalid parameterization of this test fixture.") diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 9fd135ad5c..0eeeeece0d 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -3,16 +3,16 @@ import numpy as np import pytest -from zarr.array import Array +from zarr.array import Array, nchunks_initialized from zarr.common import ZarrFormat from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.group import Group from zarr.store import LocalStore, MemoryStore from zarr.store.core import StorePath +from zarr.sync import sync -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) +@pytest.mark.parametrize("store", ("local", "memory"), indirect=True) @pytest.mark.parametrize("exists_ok", [True, False]) @pytest.mark.parametrize("extant_node", ["array", "group"]) def test_array_creation_existing_node( @@ -60,8 +60,7 @@ def test_array_creation_existing_node( ) -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) +@pytest.mark.parametrize("store", ("local", "memory"), indirect=True) def test_array_name_properties_no_group( store: LocalStore | MemoryStore, zarr_format: ZarrFormat ) -> None: @@ -71,8 +70,7 @@ def test_array_name_properties_no_group( assert arr.basename is None -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) +@pytest.mark.parametrize("store", ("local", "memory"), indirect=True) def test_array_name_properties_with_group( store: LocalStore | MemoryStore, zarr_format: ZarrFormat ) -> None: @@ -136,3 +134,25 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str assert arr.fill_value == np.dtype(dtype_str).type(fill_value) assert 
arr.fill_value.dtype == arr.dtype + + +@pytest.mark.parametrize("store", ("local", "memory"), indirect=True) +def test_nchunks_initialized(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: + """ + Test that the nchunks_initialized function accurately reports the number of initialized chunks + in storage + """ + num_chunks = 10 + array = Array.create( + store=store, shape=(num_chunks,), chunks=(1,), dtype="uint8", zarr_format=zarr_format + ) + assert array.nchunks == num_chunks + assert nchunks_initialized(array) == 0 + + for idx, region in enumerate(array._iter_chunk_regions): + array[region] = 1 + assert nchunks_initialized(array) == idx + 1 + + for idx, key in enumerate(array._iter_chunk_keys): + sync((array.store_path / key).delete()) + assert nchunks_initialized(array) == array.nchunks - (idx + 1) From b694b6ed902589cb37f70bce8b3bca304c0f9088 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 2 Aug 2024 16:33:16 +0200 Subject: [PATCH 04/22] rename _iter_chunks to _iter_chunk_coords --- src/zarr/array.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/zarr/array.py b/src/zarr/array.py index 1d1600da2b..6043e47cee 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -461,7 +461,7 @@ def nchunks(self) -> int: return product(self.cdata_shape) @property - def _iter_chunks(self) -> Iterator[ChunkCoords]: + def _iter_chunk_coords(self) -> Iterator[ChunkCoords]: """ Produce an iterator over the coordinates of each chunk, in chunk grid space. """ @@ -472,7 +472,7 @@ def _iter_chunk_keys(self) -> Iterator[str]: """ Return an iterator over the keys of each chunk. """ - for k in self._iter_chunks: + for k in self._iter_chunk_coords: yield self.metadata.encode_chunk_key(k) @property @@ -480,7 +480,7 @@ def _iter_chunk_regions(self) -> Iterator[tuple[slice, ...]]: """ Iterate over the regions spanned by each chunk. 
""" - for cgrid_position in self._iter_chunks: + for cgrid_position in self._iter_chunk_coords: out: tuple[slice, ...] = () for c_pos, c_shape in zip(cgrid_position, self.chunks, strict=False): start = c_pos * c_shape @@ -816,7 +816,7 @@ def _iter_chunks(self) -> Iterator[ChunkCoords]: """ Produce an iterator over the coordinates of each chunk, in chunk grid space. """ - yield from self._async_array._iter_chunks + yield from self._async_array._iter_chunk_coords @property def nbytes(self) -> int: From 6a27ca850b1834dfc346c0398c978f8279aefc38 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sat, 3 Aug 2024 12:38:26 +0200 Subject: [PATCH 05/22] fix test name --- src/zarr/testing/store.py | 49 +++++++++++---------------------------- 1 file changed, 13 insertions(+), 36 deletions(-) diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 8662abaeb7..ba37dda625 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -6,7 +6,6 @@ from zarr.buffer import Buffer, default_buffer_prototype from zarr.store.utils import _normalize_interval_index from zarr.sync import _collect_aiterator -from zarr.sync import _collect_aiterator from zarr.testing.utils import assert_bytes_equal S = TypeVar("S", bound=Store) @@ -105,17 +104,17 @@ async def test_set(self, store: S, key: str, data: bytes) -> None: observed = self.get(store, key) assert_bytes_equal(observed, data_buf) - async def set_set_dict(self, store: S) -> None: + async def test_set_dict(self, store: S) -> None: """ Test that a dict of key : value pairs can be inserted into the store via the `_set_dict` method. 
""" keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"] data_buf = [Buffer.from_bytes(k.encode()) for k in keys] - store_dict = dict(zip(keys, data_buf, strict=False)) + store_dict = dict(zip(keys, data_buf, strict=True)) await store._set_dict(store_dict) for k, v in store_dict.items(): - assert self.get(store, k) == v + assert self.get(store, k).to_bytes() == v.to_bytes() @pytest.mark.parametrize( "key_ranges", @@ -192,20 +191,12 @@ async def test_list(self, store: S) -> None: observed_sorted = sorted(observed) assert observed_sorted == expected_sorted - assert await _collect_aiterator(store.list()) == () - prefix = "foo" - data = Buffer.from_bytes(b"") - store_dict = { - prefix + "/zarr.json": data, - **{prefix + f"/c/{idx}": data for idx in range(10)}, - } - await store._set_dict(store_dict) - expected_sorted = sorted(store_dict.keys()) - observed = await _collect_aiterator(store.list()) - observed_sorted = sorted(observed) - assert observed_sorted == expected_sorted - async def test_list_prefix(self, store: S) -> None: + """ + Test that the `list_prefix` method works as intended. Given a prefix, it should return + all the keys in storage that start with this prefix. Keys should be returned with the shared + prefix removed. + """ prefixes = ("", "a/", "a/b/", "a/b/c/") data = Buffer.from_bytes(b"") fname = "zarr.json" @@ -213,7 +204,11 @@ async def test_list_prefix(self, store: S) -> None: await store._set_dict(store_dict) for p in prefixes: observed = tuple(sorted(await _collect_aiterator(store.list_prefix(p)))) - expected = tuple(sorted(filter(lambda v: v.startswith(p), store_dict.keys()))) + expected: tuple[str, ...] 
= () + for k in store_dict.keys(): + if k.startswith(p): + expected += (k.removeprefix(p),) + expected = tuple(sorted(expected)) assert observed == expected async def test_list_dir(self, store: S) -> None: @@ -235,21 +230,3 @@ async def test_list_dir(self, store: S) -> None: keys_observed = await _collect_aiterator(store.list_dir(root + "/")) assert sorted(keys_expected) == sorted(keys_observed) - root = "foo" - store_dict = { - root + "/zarr.json": Buffer.from_bytes(b"bar"), - root + "/c/1": Buffer.from_bytes(b"\x01"), - } - - assert await _collect_aiterator(store.list_dir("")) == () - assert await _collect_aiterator(store.list_dir(root)) == () - - await store._set_dict(store_dict) - - keys_observed = await _collect_aiterator(store.list_dir(root)) - keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()} - - assert sorted(keys_observed) == sorted(keys_expected) - - keys_observed = await _collect_aiterator(store.list_dir(root + "/")) - assert sorted(keys_expected) == sorted(keys_observed) From d15be9a1a4106901ab7f0abddb5245e2a28be32b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sat, 3 Aug 2024 12:44:18 +0200 Subject: [PATCH 06/22] bring in correct store list_dir implementations --- src/zarr/store/remote.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 0fd013cb4b..0b9e3bb8ce 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -205,6 +205,6 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: yield onefile async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: - glob_str = f"{self.path}/{prefix.rstrip('/')}**" - for onefile in await self._fs._glob(glob_str): - yield onefile + find_str = "/".join([self.path, prefix]) + for onefile in await self._fs._find(find_str): + yield onefile.removeprefix(find_str) From 962ffed90bf0223a99b4e9e401e489fcde74aad0 Mon Sep 17 00:00:00 2001 From: Davis Vann 
Bennett Date: Mon, 12 Aug 2024 23:43:29 +0200 Subject: [PATCH 07/22] bump numcodecs to dodge zstd exception --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b31b5c8dbf..f06305c14a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ 'asciitree', 'numpy>=1.24', 'fasteners', - 'numcodecs>=0.10.0', + 'numcodecs>=0.13.0', 'fsspec>2024', 'crc32c', 'typing_extensions', From 5c98ab4c2e8b673cfc307cd159d9f9a18639ed8b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 12 Aug 2024 23:45:40 +0200 Subject: [PATCH 08/22] remove store._set_dict, and add _set_many and get_many instead --- src/zarr/abc/store.py | 37 +++++++++++++++++++++++++++---------- src/zarr/store/utils.py | 6 ++++++ src/zarr/testing/store.py | 37 +++++++++++++++++++++++++++++-------- tests/v3/conftest.py | 10 ++++++++++ tests/v3/test_buffer.py | 8 ++++---- 5 files changed, 76 insertions(+), 22 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 95d12943b9..9e127efcac 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,12 +1,21 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Mapping -from typing import Any, NamedTuple, Protocol, runtime_checkable +from asyncio import gather +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Iterable + from typing import Any, TypeAlias -from typing_extensions import Self + from typing_extensions import Self from zarr.buffer import Buffer, BufferPrototype from zarr.common import AccessModeLiteral, BytesLike +ByteRangeRequest: TypeAlias = tuple[int | None, int | None] | None + class AccessMode(NamedTuple): readonly: bool @@ -76,7 +85,7 @@ async def get( self, key: str, prototype: BufferPrototype, - byte_range: tuple[int | None, int | 
None] | None = None, + byte_range: ByteRangeRequest | None = None, ) -> Buffer | None: """Retrieve the value associated with a given key. @@ -95,13 +104,13 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, tuple[int | None, int | None]]], + key_ranges: list[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: """Retrieve possibly partial values from given key_ranges. Parameters ---------- - key_ranges : list[tuple[str, tuple[int, int]]] + key_ranges : list[tuple[str, tuple[int | None, int | None]]] Ordered set of key, range pairs, a key may occur multiple times with different ranges Returns @@ -221,14 +230,22 @@ def close(self) -> None: self._is_open = False pass - async def _set_dict(self, dict: Mapping[str, Buffer]) -> None: + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: """ - Insert objects into storage as defined by a prefix: value mapping. + Insert a collection of objects into storage. """ - for key, value in dict.items(): - await self.set(key, value) + await gather(*(self.set(key, value) for key, value in values)) return None + async def _get_many( + self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest]] + ) -> AsyncGenerator[Buffer | None, None]: + """ + Retrieve a collection of objects from storage. 
+ """ + for req in requests: + yield await self.get(*req) + @runtime_checkable class ByteGetter(Protocol): diff --git a/src/zarr/store/utils.py b/src/zarr/store/utils.py index 17c9234221..7e5f6b352f 100644 --- a/src/zarr/store/utils.py +++ b/src/zarr/store/utils.py @@ -1,3 +1,9 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + pass from zarr.buffer import Buffer diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index ba37dda625..850eb0edb1 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -5,7 +5,7 @@ from zarr.abc.store import AccessMode, Store from zarr.buffer import Buffer, default_buffer_prototype from zarr.store.utils import _normalize_interval_index -from zarr.sync import _collect_aiterator +from zarr.sync import _collect_aiterator, collect_aiterator from zarr.testing.utils import assert_bytes_equal S = TypeVar("S", bound=Store) @@ -92,6 +92,27 @@ async def test_get( expected = data_buf[start : start + length] assert_bytes_equal(observed, expected) + async def test_get_many(self, store: S) -> None: + """ + Ensure that multiple keys can be retrieved at once with the _get_many method. 
+ """ + keys = tuple(map(str, range(10))) + values = tuple(f"{k}".encode() for k in keys) + for k, v in zip(keys, values, strict=False): + self.set(store, k, Buffer.from_bytes(v)) + observed_buffers = collect_aiterator( + store._get_many( + zip( + keys, + (default_buffer_prototype(),) * len(keys), + (None,) * len(keys), + strict=False, + ) + ) + ) + observed_values = tuple(b.to_bytes() for b in observed_buffers) # type: ignore + assert observed_values == values + @pytest.mark.parametrize("key", ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"]) @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""]) async def test_set(self, store: S, key: str, data: bytes) -> None: @@ -104,15 +125,15 @@ async def test_set(self, store: S, key: str, data: bytes) -> None: observed = self.get(store, key) assert_bytes_equal(observed, data_buf) - async def test_set_dict(self, store: S) -> None: + async def test_set_many(self, store: S) -> None: """ - Test that a dict of key : value pairs can be inserted into the store via the - `_set_dict` method. + Test that a collection of key : value pairs can be inserted into the store via the + `_set_many` method. 
""" keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"] data_buf = [Buffer.from_bytes(k.encode()) for k in keys] store_dict = dict(zip(keys, data_buf, strict=True)) - await store._set_dict(store_dict) + await store._set_many(store_dict.items()) for k, v in store_dict.items(): assert self.get(store, k).to_bytes() == v.to_bytes() @@ -185,7 +206,7 @@ async def test_list(self, store: S) -> None: prefix + "/zarr.json": data, **{prefix + f"/c/{idx}": data for idx in range(10)}, } - await store._set_dict(store_dict) + await store._set_many(store_dict.items()) expected_sorted = sorted(store_dict.keys()) observed = await _collect_aiterator(store.list()) observed_sorted = sorted(observed) @@ -201,7 +222,7 @@ async def test_list_prefix(self, store: S) -> None: data = Buffer.from_bytes(b"") fname = "zarr.json" store_dict = {p + fname: data for p in prefixes} - await store._set_dict(store_dict) + await store._set_many(store_dict.items()) for p in prefixes: observed = tuple(sorted(await _collect_aiterator(store.list_prefix(p)))) expected: tuple[str, ...] = () @@ -221,7 +242,7 @@ async def test_list_dir(self, store: S) -> None: assert await _collect_aiterator(store.list_dir("")) == () assert await _collect_aiterator(store.list_dir(root)) == () - await store._set_dict(store_dict) + await store._set_many(store_dict.items()) keys_observed = await _collect_aiterator(store.list_dir(root)) keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()} diff --git a/tests/v3/conftest.py b/tests/v3/conftest.py index 0a672d1f2e..6a4cc1aa80 100644 --- a/tests/v3/conftest.py +++ b/tests/v3/conftest.py @@ -122,6 +122,16 @@ def array_fixture(request: pytest.FixtureRequest) -> np.ndarray: ) +@pytest.fixture(params=(2, 3)) +def zarr_format(request: pytest.FixtureRequest) -> ZarrFormat: + if request.param == 2: + return 2 + elif request.param == 3: + return 3 + msg = f"Invalid zarr format requested. Got {request.param}, expected on of (2,3)." 
+ raise ValueError(msg) + + settings.register_profile( "ci", max_examples=1000, diff --git a/tests/v3/test_buffer.py b/tests/v3/test_buffer.py index d53e98d42d..1ca382c4ef 100644 --- a/tests/v3/test_buffer.py +++ b/tests/v3/test_buffer.py @@ -20,14 +20,14 @@ ) -def test_nd_array_like(xp): +def test_nd_array_like(xp) -> None: ary = xp.arange(10) assert isinstance(ary, ArrayLike) assert isinstance(ary, NDArrayLike) @pytest.mark.asyncio -async def test_async_array_prototype(): +async def test_async_array_prototype() -> None: """Test the use of a custom buffer prototype""" expect = np.zeros((9, 9), dtype="uint16", order="F") @@ -53,7 +53,7 @@ async def test_async_array_prototype(): @pytest.mark.asyncio -async def test_codecs_use_of_prototype(): +async def test_codecs_use_of_prototype() -> None: expect = np.zeros((10, 10), dtype="uint16", order="F") a = await AsyncArray.create( StorePath(StoreExpectingTestBuffer(mode="w")) / "test_codecs_use_of_prototype", @@ -84,7 +84,7 @@ async def test_codecs_use_of_prototype(): assert np.array_equal(expect, got) -def test_numpy_buffer_prototype(): +def test_numpy_buffer_prototype() -> None: buffer = numpy_buffer_prototype().buffer.create_zero_length() ndbuffer = numpy_buffer_prototype().nd_buffer.create(shape=(1, 2), dtype=np.dtype("int64")) assert isinstance(buffer.as_array_like(), np.ndarray) From 9e64fa8857143e477adfae3f3ae4f9ad3904c068 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 13 Aug 2024 13:59:35 +0200 Subject: [PATCH 09/22] update deprecation warning template --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f06305c14a..ae57d829ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -223,5 +223,5 @@ filterwarnings = [ "error:::zarr.*", "ignore:PY_SSIZE_T_CLEAN will be required.*:DeprecationWarning", "ignore:The loop argument is deprecated since Python 3.8.*:DeprecationWarning", - "ignore:.*is transitional and will be 
removed.*:DeprecationWarning", + "ignore:.*may be removed in an early zarr-python v3 release.:DeprecationWarning", ] From a4b46968b901b840548e1f10f7f66f6af3fb6826 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 13 Aug 2024 14:03:52 +0200 Subject: [PATCH 10/22] add a type annotation --- tests/v3/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v3/conftest.py b/tests/v3/conftest.py index 6a4cc1aa80..fc694166d8 100644 --- a/tests/v3/conftest.py +++ b/tests/v3/conftest.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Generator, Iterator from types import ModuleType from typing import TYPE_CHECKING @@ -99,7 +99,7 @@ def xp(request: pytest.FixtureRequest) -> Iterator[ModuleType]: @pytest.fixture(autouse=True) -def reset_config(): +def reset_config() -> Generator[None, None, None]: config.reset() yield config.reset() From 04b1d6a716945d293b43867738359b2bdbb2fea2 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 13 Aug 2024 14:08:10 +0200 Subject: [PATCH 11/22] refactor chunk iterators. 
they are not properties any more, just methods, and they can take an origin kwarg --- src/zarr/array.py | 57 +++++++++++++++++++-------------------- src/zarr/indexing.py | 61 +++++++++++++++++++++++++++++++++++------- tests/v3/test_array.py | 4 +-- 3 files changed, 80 insertions(+), 42 deletions(-) diff --git a/src/zarr/array.py b/src/zarr/array.py index 83042e775e..a6e38fed05 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -12,7 +12,13 @@ from asyncio import gather from collections.abc import Iterable, Iterator from dataclasses import dataclass, field, replace -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, cast + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Any, Literal + + from zarr.common import JSON, ChunkCoords, ZarrFormat import numpy as np import numpy.typing as npt @@ -27,12 +33,9 @@ from zarr.codecs import BytesCodec from zarr.codecs._v2 import V2Compressor, V2Filters from zarr.common import ( - JSON, ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON, - ChunkCoords, - ZarrFormat, concurrent_map, product, ) @@ -53,13 +56,13 @@ OrthogonalSelection, Selection, VIndex, + _iter_grid, ceildiv, check_fields, check_no_multi_fields, is_pure_fancy_indexing, is_pure_orthogonal_indexing, is_scalar, - iter_grid, pop_fields, ) from zarr.metadata import ArrayMetadata, ArrayV2Metadata, ArrayV3Metadata @@ -443,9 +446,7 @@ def basename(self) -> str | None: return None @property - @deprecated( - "cdata_shape is transitional and will be removed in an early zarr-python v3 release." - ) + @deprecated("AsyncArray.cdata_shape may be removed in an early zarr-python v3 release.") def cdata_shape(self) -> ChunkCoords: """ The shape of the chunk grid for this array. 
@@ -453,34 +454,32 @@ def cdata_shape(self) -> ChunkCoords: return tuple(ceildiv(s, c) for s, c in zip(self.shape, self.chunks, strict=False)) @property - @deprecated("nchunks is transitional and will be removed in an early zarr-python v3 release.") + @deprecated("AsyncArray.nchunks may be removed in an early zarr-python v3 release.") def nchunks(self) -> int: """ The number of chunks in the stored representation of this array. """ return product(self.cdata_shape) - @property - def _iter_chunk_coords(self) -> Iterator[ChunkCoords]: + def _iter_chunk_coords(self, origin: Sequence[int] | None = None) -> Iterator[ChunkCoords]: """ - Produce an iterator over the coordinates of each chunk, in chunk grid space. + Produce an iterator over the coordinates of each chunk, in chunk grid space, relative to + an optional origin. """ - return iter_grid(self.cdata_shape) + return _iter_grid(self.cdata_shape, origin=origin) - @property - def _iter_chunk_keys(self) -> Iterator[str]: + def _iter_chunk_keys(self, origin: Sequence[int] | None = None) -> Iterator[str]: """ - Return an iterator over the keys of each chunk. + Return an iterator over the storage keys of each chunk, relative to an optional origin. """ - for k in self._iter_chunk_coords: + for k in self._iter_chunk_coords(origin=origin): yield self.metadata.encode_chunk_key(k) - @property def _iter_chunk_regions(self) -> Iterator[tuple[slice, ...]]: """ Iterate over the regions spanned by each chunk. """ - for cgrid_position in self._iter_chunk_coords: + for cgrid_position in self._iter_chunk_coords(): out: tuple[slice, ...] 
= () for c_pos, c_shape in zip(cgrid_position, self.chunks, strict=False): start = c_pos * c_shape @@ -811,12 +810,12 @@ def nchunks(self) -> int: """ return self._async_array.nchunks - @property - def _iter_chunks(self) -> Iterator[ChunkCoords]: + def _iter_chunks(self, origin: Sequence[int] | None = None) -> Iterator[ChunkCoords]: """ - Produce an iterator over the coordinates of each chunk, in chunk grid space. + Produce an iterator over the coordinates of each chunk, in chunk grid space, relative + to an optional origin. """ - yield from self._async_array._iter_chunk_coords + yield from self._async_array._iter_chunk_coords(origin=origin) @property def nbytes(self) -> int: @@ -825,19 +824,17 @@ def nbytes(self) -> int: """ return self._async_array.nbytes - @property - def _iter_chunk_keys(self) -> Iterator[str]: + def _iter_chunk_keys(self, origin: Sequence[int] | None = None) -> Iterator[str]: """ - Return an iterator over the keys of each chunk. + Return an iterator over the keys of each chunk, relative to an optional origin """ - yield from self._async_array._iter_chunk_keys + yield from self._async_array._iter_chunk_keys(origin=origin) - @property def _iter_chunk_regions(self) -> Iterator[tuple[slice, ...]]: """ Iterate over the regions spanned by each chunk. 
""" - yield from self._async_array._iter_chunk_regions + yield from self._async_array._iter_chunk_regions() def __array__( self, dtype: npt.DTypeLike | None = None, copy: bool | None = None @@ -2179,7 +2176,7 @@ def chunks_initialized(array: Array) -> tuple[str, ...]: ) out: list[str] = [] - for chunk_key in array._iter_chunk_keys: + for chunk_key in array._iter_chunk_keys(): if chunk_key in store_contents: out.append(chunk_key) diff --git a/src/zarr/indexing.py b/src/zarr/indexing.py index 91ab571b36..7e53a0ba62 100644 --- a/src/zarr/indexing.py +++ b/src/zarr/indexing.py @@ -4,15 +4,17 @@ import math import numbers import operator -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Iterator, Sequence from dataclasses import dataclass from enum import Enum from functools import reduce from types import EllipsisType from typing import ( TYPE_CHECKING, + Literal, NamedTuple, Protocol, + TypeAlias, TypeGuard, TypeVar, cast, @@ -87,12 +89,35 @@ def ceildiv(a: float, b: float) -> int: return math.ceil(a / b) -def iter_grid(shape: Iterable[int]) -> Iterator[ChunkCoords]: +_ArrayIndexingOrder: TypeAlias = Literal["lexicographic"] + + +def _iter_grid( + shape: Sequence[int], + *, + origin: Sequence[int] | None = None, + order: _ArrayIndexingOrder = "lexicographic", +) -> Iterator[ChunkCoords]: """ - Iterate over the elements of grid. + Iterate over the elements of grid of integers. + + Takes a grid shape expressed as a sequence of integers and an optional origin and + yields tuples bounded by [origin, origin + grid_shape]. + + Parameters + --------- + shape: Sequence[int] + The size of the domain to iterate over. + origin: Sequence[int] | None, default=None + The first coordinate of the domain. + order: Literal["lexicographic"], default="lexicographic" + The linear indexing order to use. + + Returns + ------- - Takes a grid shape expressed as an iterable of ints and - yields tuples bounded by that grid shape in lexicographic order. 
+ itertools.product object + An iterator over tuples of integers Examples -------- @@ -102,12 +127,28 @@ def iter_grid(shape: Iterable[int]) -> Iterator[ChunkCoords]: >>> tuple(iter_grid((2,3))) ((0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)) - Parameters - ---------- - shape: Iterable[int] - The shape of the grid to iterate over. + >>> tuple(iter_grid((2,3), origin=(1,1))) + ((1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)) """ - yield from itertools.product(*(map(range, shape))) + if origin is None: + origin_parsed = (0,) * len(shape) + else: + if len(origin) != len(shape): + msg = ( + "Shape and origin parameters must have the same length. " + f"Got {len(shape)} elements in shape, but {len(origin)} elements in origin." + ) + raise ValueError(msg) + origin_parsed = tuple(origin) + + if order == "lexicographic": + yield from itertools.product( + *(range(o, o + s) for o, s in zip(origin_parsed, shape, strict=True)) + ) + + else: + msg = f"Indexing order {order} is not supported at this time." 
# type: ignore[unreachable] + raise NotImplementedError(msg) def is_integer(x: Any) -> TypeGuard[int]: diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 0eeeeece0d..d1203b77bc 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -149,10 +149,10 @@ def test_nchunks_initialized(store: LocalStore | MemoryStore, zarr_format: ZarrF assert array.nchunks == num_chunks assert nchunks_initialized(array) == 0 - for idx, region in enumerate(array._iter_chunk_regions): + for idx, region in enumerate(array._iter_chunk_regions()): array[region] = 1 assert nchunks_initialized(array) == idx + 1 - for idx, key in enumerate(array._iter_chunk_keys): + for idx, key in enumerate(array._iter_chunk_keys()): sync((array.store_path / key).delete()) assert nchunks_initialized(array) == array.nchunks - (idx + 1) From b7c1a56ee89c25d931488f8615724ddad66bf13c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 19 Sep 2024 19:40:59 +0200 Subject: [PATCH 12/22] _get_many returns tuple[str, buffer] --- src/zarr/abc/store.py | 8 +++++--- src/zarr/testing/store.py | 5 +++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index d1a2ee6940..4aee2528cf 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -263,12 +263,14 @@ def close(self) -> None: async def _get_many( self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest]] - ) -> AsyncGenerator[Buffer | None, None]: + ) -> AsyncGenerator[tuple[str, Buffer | None], None]: """ - Retrieve a collection of objects from storage. + Retrieve a collection of objects from storage. 
In general this method does not guarantee + that objects will be retrieved in the order in which they were requested, so this method + yields tuple[str, Buffer | None] instead of just Buffer | None """ for req in requests: - yield await self.get(*req) + yield (req[0], await self.get(*req)) @runtime_checkable diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index e3fba54883..267ca9ead6 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -129,8 +129,9 @@ async def test_get_many(self, store: S) -> None: ) ) ) - observed_values = tuple(b.to_bytes() for b in observed_buffers) # type: ignore - assert observed_values == values + observed_kvs = sorted(tuple((k, b.to_bytes()) for k, b in observed_buffers)) + expected_kvs = sorted(tuple((k, b) for k, b in zip(keys, values, strict=False))) + assert observed_kvs == expected_kvs @pytest.mark.parametrize("key", ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"]) @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""]) From 44bed5c8255095b83f5d2e3cd7fa8d5fcda34bee Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 19 Sep 2024 19:48:39 +0200 Subject: [PATCH 13/22] stricter store types --- src/zarr/abc/store.py | 16 +++++++++------- src/zarr/store/local.py | 5 +++-- src/zarr/store/memory.py | 4 ++-- src/zarr/store/remote.py | 6 +++--- src/zarr/store/zip.py | 6 +++--- 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 4aee2528cf..6808e982f1 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -16,7 +16,7 @@ __all__ = ["Store", "AccessMode", "ByteGetter", "ByteSetter", "set_or_delete"] -ByteRangeRequest: TypeAlias = tuple[int | None, int | None] | None +ByteRangeRequest: TypeAlias = tuple[int | None, int | None] class AccessMode(NamedTuple): @@ -121,7 +121,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, ByteRangeRequest]], + key_ranges: 
Iterable[tuple[str, ByteRangeRequest | None]], ) -> list[Buffer | None]: """Retrieve possibly partial values from given key_ranges. @@ -197,7 +197,9 @@ def supports_partial_writes(self) -> bool: ... @abstractmethod - async def set_partial_values(self, key_start_values: list[tuple[str, int, BytesLike]]) -> None: + async def set_partial_values( + self, key_start_values: Iterable[tuple[str, int, BytesLike]] + ) -> None: """Store values at a given key, starting at byte range_start. Parameters @@ -262,7 +264,7 @@ def close(self) -> None: self._is_open = False async def _get_many( - self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest]] + self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: """ Retrieve a collection of objects from storage. In general this method does not guarantee @@ -276,17 +278,17 @@ async def _get_many( @runtime_checkable class ByteGetter(Protocol): async def get( - self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None + self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None ) -> Buffer | None: ... @runtime_checkable class ByteSetter(Protocol): async def get( - self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None + self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None ) -> Buffer | None: ... - async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: ... + async def set(self, value: Buffer, byte_range: ByteRangeRequest = None) -> None: ... async def delete(self) -> None: ... 
diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index c78837586f..4702a65b13 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -3,10 +3,11 @@ import io import os import shutil +from collections.abc import Iterable from pathlib import Path from typing import TYPE_CHECKING -from zarr.abc.store import Store +from zarr.abc.store import ByteRangeRequest, Store from zarr.core.buffer import Buffer from zarr.core.common import concurrent_map, to_thread @@ -127,7 +128,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, tuple[int | None, int | None]]], + key_ranges: Iterable[ByteRangeRequest], ) -> list[Buffer | None]: """ Read byte ranges from multiple keys. diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index e304419768..a4ad3d2ad6 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -3,7 +3,7 @@ from collections.abc import AsyncGenerator, MutableMapping from typing import TYPE_CHECKING, Any -from zarr.abc.store import Store +from zarr.abc.store import ByteRangeRequest, Store from zarr.core.buffer import Buffer, gpu from zarr.core.common import concurrent_map from zarr.store._utils import _normalize_interval_index @@ -80,7 +80,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, tuple[int | None, int | None]]], + key_ranges: list[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: # All the key-ranges arguments goes with the same prototype async def _get(key: str, byte_range: tuple[int, int | None]) -> Buffer | None: diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 084ef986b1..a1463d9c09 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -4,7 +4,7 @@ import fsspec -from zarr.abc.store import Store +from zarr.abc.store import ByteRangeRequest, Store from zarr.store.common import _dereference_path if TYPE_CHECKING: @@ -105,7 +105,7 @@ 
async def get( self, key: str, prototype: BufferPrototype, - byte_range: tuple[int | None, int | None] | None = None, + byte_range: ByteRangeRequest = None, ) -> Buffer | None: if not self._is_open: await self._open() @@ -172,7 +172,7 @@ async def exists(self, key: str) -> bool: async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, tuple[int | None, int | None]]], + key_ranges: list[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: if key_ranges: paths, starts, stops = zip( diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py index ea31ad934a..682d287240 100644 --- a/src/zarr/store/zip.py +++ b/src/zarr/store/zip.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal -from zarr.abc.store import Store +from zarr.abc.store import ByteRangeRequest, Store from zarr.core.buffer import Buffer, BufferPrototype if TYPE_CHECKING: @@ -128,7 +128,7 @@ def _get( self, key: str, prototype: BufferPrototype, - byte_range: tuple[int | None, int | None] | None = None, + byte_range: ByteRangeRequest = None, ) -> Buffer | None: try: with self._zf.open(key) as f: # will raise KeyError @@ -161,7 +161,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, tuple[int | None, int | None]]], + key_ranges: list[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: out = [] with self._lock: From 2db860b9449303e979fdd0e01bf311a5aa087adf Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 23 Sep 2024 23:24:47 +0200 Subject: [PATCH 14/22] fix types --- src/zarr/abc/store.py | 11 +++++------ src/zarr/codecs/sharding.py | 6 +++--- src/zarr/core/common.py | 2 +- src/zarr/store/common.py | 6 +++--- src/zarr/store/local.py | 9 +++++---- src/zarr/store/memory.py | 9 ++++----- src/zarr/store/remote.py | 10 ++++++---- src/zarr/store/zip.py | 10 +++++----- 8 files changed, 32 insertions(+), 31 deletions(-) diff --git a/src/zarr/abc/store.py 
b/src/zarr/abc/store.py index 6808e982f1..2890e93aa0 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -2,17 +2,16 @@ from abc import ABC, abstractmethod from asyncio import gather -from collections.abc import AsyncGenerator, Iterable from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, runtime_checkable if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import AsyncGenerator, Iterable from typing import Any, TypeAlias from typing_extensions import Self -from zarr.core.buffer import Buffer, BufferPrototype -from zarr.core.common import AccessModeLiteral, BytesLike + from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.common import AccessModeLiteral, BytesLike __all__ = ["Store", "AccessMode", "ByteGetter", "ByteSetter", "set_or_delete"] @@ -121,7 +120,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRangeRequest | None]], + key_ranges: Iterable[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: """Retrieve possibly partial values from given key_ranges. @@ -288,7 +287,7 @@ async def get( self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None ) -> Buffer | None: ... - async def set(self, value: Buffer, byte_range: ByteRangeRequest = None) -> None: ... + async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: ... async def delete(self) -> None: ... 
diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 3ae51ce54b..1c89a67e80 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -17,7 +17,7 @@ Codec, CodecPipeline, ) -from zarr.abc.store import ByteGetter, ByteSetter +from zarr.abc.store import ByteGetter, ByteRangeRequest, ByteSetter from zarr.codecs.bytes import BytesCodec from zarr.codecs.crc32c_ import Crc32cCodec from zarr.core.array_spec import ArraySpec @@ -78,7 +78,7 @@ class _ShardingByteGetter(ByteGetter): chunk_coords: ChunkCoords async def get( - self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None + self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None ) -> Buffer | None: assert byte_range is None, "byte_range is not supported within shards" assert ( @@ -91,7 +91,7 @@ async def get( class _ShardingByteSetter(_ShardingByteGetter, ByteSetter): shard_dict: ShardMutableMapping - async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: assert byte_range is None, "byte_range is not supported within shards" self.shard_dict[self.chunk_coords] = value diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 8ebe5160bd..379a24baac 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -47,7 +47,7 @@ def product(tup: ChunkCoords) -> int: async def concurrent_map( - items: list[T], func: Callable[..., Awaitable[V]], limit: int | None = None + items: Iterable[T], func: Callable[..., Awaitable[V]], limit: int | None = None ) -> list[V]: if limit is None: return await asyncio.gather(*[func(*item) for item in items]) diff --git a/src/zarr/store/common.py b/src/zarr/store/common.py index 196479dd67..3388235710 100644 --- a/src/zarr/store/common.py +++ b/src/zarr/store/common.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal -from 
zarr.abc.store import AccessMode, Store +from zarr.abc.store import AccessMode, ByteRangeRequest, Store from zarr.core.buffer import Buffer, default_buffer_prototype from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZGROUP_JSON, ZarrFormat from zarr.errors import ContainsArrayAndGroupError, ContainsArrayError, ContainsGroupError @@ -38,13 +38,13 @@ def __init__(self, store: Store, path: str | None = None): async def get( self, prototype: BufferPrototype | None = None, - byte_range: tuple[int, int | None] | None = None, + byte_range: ByteRangeRequest | None = None, ) -> Buffer | None: if prototype is None: prototype = default_buffer_prototype() return await self.store.get(self.path, prototype=prototype, byte_range=byte_range) - async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: if byte_range is not None: raise NotImplementedError("Store.set does not have partial writes yet") await self.store.set(self.path, value) diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index 4702a65b13..44ec9f21ee 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -3,7 +3,6 @@ import io import os import shutil -from collections.abc import Iterable from pathlib import Path from typing import TYPE_CHECKING @@ -12,7 +11,7 @@ from zarr.core.common import concurrent_map, to_thread if TYPE_CHECKING: - from collections.abc import AsyncGenerator + from collections.abc import AsyncGenerator, Iterable from zarr.core.buffer import BufferPrototype from zarr.core.common import AccessModeLiteral @@ -128,7 +127,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: Iterable[ByteRangeRequest], + key_ranges: Iterable[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: """ Read byte ranges from multiple keys. 
@@ -158,7 +157,9 @@ async def set(self, key: str, value: Buffer) -> None: path = self.root / key await to_thread(_put, path, value) - async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None: + async def set_partial_values( + self, key_start_values: Iterable[tuple[str, int, bytes | bytearray | memoryview]] + ) -> None: self._check_writable() args = [] for key, start, value in key_start_values: diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 588fc4c885..53f216264f 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import AsyncGenerator, MutableMapping from typing import TYPE_CHECKING from zarr.abc.store import ByteRangeRequest, Store @@ -9,7 +8,7 @@ from zarr.store._utils import _normalize_interval_index if TYPE_CHECKING: - from collections.abc import AsyncGenerator, MutableMapping + from collections.abc import AsyncGenerator, Iterable, MutableMapping from zarr.core.buffer import BufferPrototype from zarr.core.common import AccessModeLiteral @@ -74,10 +73,10 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, ByteRangeRequest]], + key_ranges: Iterable[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: # All the key-ranges arguments goes with the same prototype - async def _get(key: str, byte_range: tuple[int, int | None]) -> Buffer | None: + async def _get(key: str, byte_range: ByteRangeRequest) -> Buffer | None: return await self.get(key, prototype=prototype, byte_range=byte_range) vals = await concurrent_map(key_ranges, _get, limit=None) @@ -108,7 +107,7 @@ async def delete(self, key: str) -> None: except KeyError: pass # Q(JH): why not raise? 
- async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None: + async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None: raise NotImplementedError async def list(self) -> AsyncGenerator[str, None]: diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 8db63df7c4..46fbd82776 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -8,7 +8,7 @@ from zarr.store.common import _dereference_path if TYPE_CHECKING: - from collections.abc import AsyncGenerator + from collections.abc import AsyncGenerator, Iterable from fsspec.asyn import AsyncFileSystem @@ -109,7 +109,7 @@ async def get( self, key: str, prototype: BufferPrototype, - byte_range: ByteRangeRequest = None, + byte_range: ByteRangeRequest | None = None, ) -> Buffer | None: if not self._is_open: await self._open() @@ -176,7 +176,7 @@ async def exists(self, key: str) -> bool: async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, ByteRangeRequest]], + key_ranges: Iterable[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: if key_ranges: paths, starts, stops = zip( @@ -202,7 +202,9 @@ async def get_partial_values( return [None if isinstance(r, Exception) else prototype.buffer.from_bytes(r) for r in res] - async def set_partial_values(self, key_start_values: list[tuple[str, int, BytesLike]]) -> None: + async def set_partial_values( + self, key_start_values: Iterable[tuple[str, int, BytesLike]] + ) -> None: raise NotImplementedError async def list(self) -> AsyncGenerator[str, None]: diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py index 682d287240..ee57ab590d 100644 --- a/src/zarr/store/zip.py +++ b/src/zarr/store/zip.py @@ -11,7 +11,7 @@ from zarr.core.buffer import Buffer, BufferPrototype if TYPE_CHECKING: - from collections.abc import AsyncGenerator + from collections.abc import AsyncGenerator, Iterable ZipStoreAccessModeLiteral = Literal["r", "w", "a"] @@ 
-128,7 +128,7 @@ def _get( self, key: str, prototype: BufferPrototype, - byte_range: ByteRangeRequest = None, + byte_range: ByteRangeRequest | None = None, ) -> Buffer | None: try: with self._zf.open(key) as f: # will raise KeyError @@ -151,7 +151,7 @@ async def get( self, key: str, prototype: BufferPrototype, - byte_range: tuple[int | None, int | None] | None = None, + byte_range: ByteRangeRequest | None = None, ) -> Buffer | None: assert isinstance(key, str) @@ -161,7 +161,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, ByteRangeRequest]], + key_ranges: Iterable[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: out = [] with self._lock: @@ -188,7 +188,7 @@ async def set(self, key: str, value: Buffer) -> None: with self._lock: self._set(key, value) - async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None: + async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None: raise NotImplementedError async def delete(self, key: str) -> None: From b5e08e862bdd5cf887983923d54143759ab16d28 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 24 Sep 2024 12:10:26 +0200 Subject: [PATCH 15/22] lint --- src/zarr/core/array.py | 3 +-- src/zarr/store/_utils.py | 5 ++++- src/zarr/testing/store.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 9a612ceb5b..b175e9776a 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -2,7 +2,6 @@ import json from asyncio import gather -from collections.abc import Iterator, Sequence from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, Any, Literal, cast @@ -71,7 +70,7 @@ ) if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Iterator, Sequence from zarr.abc.codec import Codec, CodecPipeline from zarr.core.metadata.common import 
ArrayMetadata diff --git a/src/zarr/store/_utils.py b/src/zarr/store/_utils.py index 99ee606752..cbc9c42bbd 100644 --- a/src/zarr/store/_utils.py +++ b/src/zarr/store/_utils.py @@ -1,6 +1,9 @@ from __future__ import annotations -from zarr.core.buffer import Buffer +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from zarr.core.buffer import Buffer def _normalize_interval_index( diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index a27130ac26..a346c860d1 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -129,7 +129,7 @@ async def test_get_many(self, store: S) -> None: ) ) ) - observed_kvs = sorted(tuple((k, b.to_bytes()) for k, b in observed_buffers)) + observed_kvs = sorted(tuple((k, b.to_bytes()) for k, b in observed_buffers)) # type: ignore[union-attr] expected_kvs = sorted(tuple((k, b) for k, b in zip(keys, values, strict=False))) assert observed_kvs == expected_kvs From 43743e14510ca6cacb1914eb76a24a479093134f Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 24 Sep 2024 12:23:49 +0200 Subject: [PATCH 16/22] remove deprecation warnings --- src/zarr/core/array.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index b175e9776a..c24758cf67 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -7,7 +7,6 @@ import numpy as np import numpy.typing as npt -from typing_extensions import deprecated from zarr._compat import _deprecate_positional_args from zarr.abc.codec import Codec, CodecPipeline @@ -449,7 +448,6 @@ def basename(self) -> str | None: return None @property - @deprecated("AsyncArray.cdata_shape may be removed in an early zarr-python v3 release.") def cdata_shape(self) -> ChunkCoords: """ The shape of the chunk grid for this array. 
@@ -457,7 +455,6 @@ def cdata_shape(self) -> ChunkCoords: return tuple(ceildiv(s, c) for s, c in zip(self.shape, self.chunks, strict=False)) @property - @deprecated("AsyncArray.nchunks may be removed in an early zarr-python v3 release.") def nchunks(self) -> int: """ The number of chunks in the stored representation of this array. @@ -806,9 +803,6 @@ def fill_value(self) -> Any: return self.metadata.fill_value @property - @deprecated( - "cdata_shape is transitional and will be removed in an early zarr-python v3 release." - ) def cdata_shape(self) -> ChunkCoords: """ The shape of the chunk grid for this array. @@ -816,7 +810,6 @@ def cdata_shape(self) -> ChunkCoords: return tuple(ceildiv(s, c) for s, c in zip(self.shape, self.chunks, strict=False)) @property - @deprecated("nchunks is transitional and will be removed in an early zarr-python v3 release.") def nchunks(self) -> int: """ The number of chunks in the stored representation of this array. @@ -2182,18 +2175,19 @@ def info(self) -> None: ) -@deprecated( - "nchunks_initialized is transitional and will be removed in an early zarr-python v3 release." -) def nchunks_initialized(array: Array) -> int: + """ + Calculate the number of chunks that have been initialized, i.e. the number of chunks that have + been persisted to the storage backend. + """ return len(chunks_initialized(array)) def chunks_initialized(array: Array) -> tuple[str, ...]: """ - Return the keys of all the chunks that exist in storage. + Return the keys of the chunks that have been persisted to the storage backend. 
""" - # todo: make this compose with the underlying async iterator + # TODO: make this compose with the underlying async iterator store_contents = list( collect_aiterator(array.store_path.store.list_prefix(prefix=array.store_path.path)) ) From f65a6e808c476c1e13734cc3e9baee5b38260332 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 24 Sep 2024 12:34:56 +0200 Subject: [PATCH 17/22] fix zip list_prefix --- src/zarr/store/zip.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py index ee57ab590d..8afa983c1c 100644 --- a/src/zarr/store/zip.py +++ b/src/zarr/store/zip.py @@ -209,9 +209,21 @@ async def list(self) -> AsyncGenerator[str, None]: yield key async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: + """ + Retrieve all keys in the store that begin with a given prefix. Keys are returned with the + common leading prefix removed. + + Parameters + ---------- + prefix : str + + Returns + ------- + AsyncGenerator[str, None] + """ async for key in self.list(): if key.startswith(prefix): - yield key + yield key.removeprefix(prefix) async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: if prefix.endswith("/"): From df6f9a72c500addadac947004646cd4b034be37c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 24 Sep 2024 14:54:12 +0200 Subject: [PATCH 18/22] tests for nchunks_initialized, chunks_initialized; add selection_shape kwarg to grid iteration; make chunk grid iterators consistent for array and async array --- src/zarr/core/array.py | 187 ++++++++++++++++++++++++++++++++++---- src/zarr/core/indexing.py | 44 ++++++--- tests/v3/test_array.py | 55 +++++++++++ tests/v3/test_indexing.py | 52 +++++++++++ 4 files changed, 304 insertions(+), 34 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index c24758cf67..70f7bba036 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -461,25 +461,83 @@ def nchunks(self) 
-> int: """ return product(self.cdata_shape) - def _iter_chunk_coords(self, origin: Sequence[int] | None = None) -> Iterator[ChunkCoords]: + @property + def nchunks_initialized(self) -> int: + """ + The number of chunks that have been persisted in storage. """ - Produce an iterator over the coordinates of each chunk, in chunk grid space, relative to - an optional origin. + return nchunks_initialized(self) + + def _iter_chunk_coords( + self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[ChunkCoords]: """ - return _iter_grid(self.cdata_shape, origin=origin) + Create an iterator over the coordinates of chunks in chunk grid space. If the `origin` + keyword is used, iteration will start at the chunk index specified by `origin`. + The default behavior is to start at the origin of the grid coordinate space. + If the `selection_shape` keyword is used, iteration will be bounded over a contiguous region + ranging from `[origin, origin + selection_shape]`, where the upper bound is exclusive as + per python indexing conventions. - def _iter_chunk_keys(self, origin: Sequence[int] | None = None) -> Iterator[str]: + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + chunk_coords: ChunkCoords + The coordinates of each chunk in the selection. """ - Return an iterator over the storage keys of each chunk, relative to an optional origin. + return _iter_grid(self.cdata_shape, origin=origin, selection_shape=selection_shape) + + def _iter_chunk_keys( + self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[str]: + """ + Iterate over the storage keys of each chunk, relative to an optional origin, and optionally + limited to a contiguous region in chunk grid coordinates. 
+ + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + key: str + The storage key of each chunk in the selection. """ - for k in self._iter_chunk_coords(origin=origin): + # Iterate over the coordinates of chunks in chunk grid space. + for k in self._iter_chunk_coords(origin=origin, selection_shape=selection_shape): + # Encode the chunk key from the chunk coordinates. yield self.metadata.encode_chunk_key(k) - def _iter_chunk_regions(self) -> Iterator[tuple[slice, ...]]: + def _iter_chunk_regions( + self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[tuple[slice, ...]]: """ Iterate over the regions spanned by each chunk. + + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + region: tuple[slice, ...] + A tuple of slice objects representing the region spanned by each chunk in the selection. """ - for cgrid_position in self._iter_chunk_coords(): + for cgrid_position in self._iter_chunk_coords( + origin=origin, selection_shape=selection_shape + ): out: tuple[slice, ...] 
= () for c_pos, c_shape in zip(cgrid_position, self.chunks, strict=False): start = c_pos * c_shape @@ -816,12 +874,32 @@ def nchunks(self) -> int: """ return self._async_array.nchunks - def _iter_chunks(self, origin: Sequence[int] | None = None) -> Iterator[ChunkCoords]: + def _iter_chunk_coords( + self, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[ChunkCoords]: """ - Produce an iterator over the coordinates of each chunk, in chunk grid space, relative - to an optional origin. + Create an iterator over the coordinates of chunks in chunk grid space. If the `origin` + keyword is used, iteration will start at the chunk index specified by `origin`. + The default behavior is to start at the origin of the grid coordinate space. + If the `selection_shape` keyword is used, iteration will be bounded over a contiguous region + ranging from `[origin, origin + selection_shape]`, where the upper bound is exclusive as + per python indexing conventions. + + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + chunk_coords: ChunkCoords + The coordinates of each chunk in the selection. """ - yield from self._async_array._iter_chunk_coords(origin=origin) + yield from self._async_array._iter_chunk_coords( + origin=origin, selection_shape=selection_shape + ) @property def nbytes(self) -> int: @@ -830,17 +908,57 @@ def nbytes(self) -> int: """ return self._async_array.nbytes - def _iter_chunk_keys(self, origin: Sequence[int] | None = None) -> Iterator[str]: + @property + def nchunks_initialized(self) -> int: + """ + The number of chunks that have been initialized in the stored representation of this array. 
+ """ + return self._async_array.nchunks_initialized + + def _iter_chunk_keys( + self, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[str]: """ - Return an iterator over the keys of each chunk, relative to an optional origin + Iterate over the storage keys of each chunk, relative to an optional origin, and optionally + limited to a contiguous region in chunk grid coordinates. + + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + key: str + The storage key of each chunk in the selection. """ - yield from self._async_array._iter_chunk_keys(origin=origin) + yield from self._async_array._iter_chunk_keys( + origin=origin, selection_shape=selection_shape + ) - def _iter_chunk_regions(self) -> Iterator[tuple[slice, ...]]: + def _iter_chunk_regions( + self, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[tuple[slice, ...]]: """ Iterate over the regions spanned by each chunk. + + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + region: tuple[slice, ...] + A tuple of slice objects representing the region spanned by each chunk in the selection. 
""" - yield from self._async_array._iter_chunk_regions() + yield from self._async_array._iter_chunk_regions( + origin=origin, selection_shape=selection_shape + ) def __array__( self, dtype: npt.DTypeLike | None = None, copy: bool | None = None @@ -2175,17 +2293,46 @@ def info(self) -> None: ) -def nchunks_initialized(array: Array) -> int: +def nchunks_initialized(array: AsyncArray | Array) -> int: """ Calculate the number of chunks that have been initialized, i.e. the number of chunks that have been persisted to the storage backend. + + Parameters + ---------- + array : Array + The array to inspect. + + Returns + ------- + nchunks_initialized : int + The number of chunks that have been initialized. + + See Also + -------- + chunks_initialized """ return len(chunks_initialized(array)) -def chunks_initialized(array: Array) -> tuple[str, ...]: +def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]: """ Return the keys of the chunks that have been persisted to the storage backend. + + Parameters + ---------- + array : Array + The array to inspect. + + Returns + ------- + chunks_initialized : tuple[str, ...] + The keys of the chunks that have been initialized. + + See Also + -------- + nchunks_initialized + """ # TODO: make this compose with the underlying async iterator store_contents = list( diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 732b3b35a0..0c5fbd52ea 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -93,23 +93,24 @@ def ceildiv(a: float, b: float) -> int: def _iter_grid( - shape: Sequence[int], + grid_shape: Sequence[int], *, origin: Sequence[int] | None = None, + selection_shape: Sequence[int] | None = None, order: _ArrayIndexingOrder = "lexicographic", ) -> Iterator[ChunkCoords]: """ - Iterate over the elements of grid of integers. - - Takes a grid shape expressed as a sequence of integers and an optional origin and - yields tuples bounded by [origin, origin + grid_shape]. 
+ Iterate over the elements of grid of integers, with the option to restrict the domain of + iteration to those from a contiguous subregion of that grid. Parameters --------- - shape: Sequence[int] + grid_shape: Sequence[int] The size of the domain to iterate over. origin: Sequence[int] | None, default=None - The first coordinate of the domain. + The first coordinate of the domain to return. + selection_shape: Sequence[int] | None, default=None + The shape of the selection. order: Literal["lexicographic"], default="lexicographic" The linear indexing order to use. @@ -129,22 +130,37 @@ def _iter_grid( >>> tuple(iter_grid((2,3)), origin=(1,1)) ((1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)) + + >>> tuple(iter_grid((2,3)), origin=(1,1), selection_shape=(2,2)) + ((1, 1), (1, 2), (1, 3), (2, 1)) """ if origin is None: - origin_parsed = (0,) * len(shape) + origin_parsed = (0,) * len(grid_shape) else: - if len(origin) != len(shape): + if len(origin) != len(grid_shape): msg = ( "Shape and origin parameters must have the same length." - f"Got {len(shape)} elements in shape, but {len(origin)} elements in origin." + f"Got {len(grid_shape)} elements in shape, but {len(origin)} elements in origin." ) raise ValueError(msg) origin_parsed = tuple(origin) - - if order == "lexicographic": - yield from itertools.product( - *(range(o, o + s) for o, s in zip(origin_parsed, shape, strict=True)) + if selection_shape is None: + selection_shape_parsed = tuple( + g - o for o, g in zip(origin_parsed, grid_shape, strict=True) ) + else: + selection_shape_parsed = tuple(selection_shape) + if order == "lexicographic": + dimensions: tuple[range, ...] = () + for idx, (o, gs, ss) in enumerate( + zip(origin_parsed, grid_shape, selection_shape_parsed, strict=True) + ): + if o + ss > gs: + raise IndexError( + f"Invalid selection shape ({selection_shape}) for origin ({origin}) and grid shape ({grid_shape}) at axis {idx}." 
+ ) + dimensions += (range(o, o + ss),) + yield from itertools.product(*(dimensions)) else: msg = f"Indexing order {order} is not supported at this time." # type: ignore[unreachable] diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index b3362c52b0..d0f32a4470 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -1,12 +1,15 @@ import pickle +from itertools import accumulate from typing import Literal import numpy as np import pytest +from src.zarr.core.array import chunks_initialized from zarr import Array, AsyncArray, Group from zarr.core.buffer.cpu import NDBuffer from zarr.core.common import ZarrFormat +from zarr.core.sync import sync from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.store import LocalStore, MemoryStore from zarr.store.common import StorePath @@ -232,3 +235,55 @@ def test_serializable_sync_array(store: LocalStore, zarr_format: ZarrFormat) -> assert actual == expected np.testing.assert_array_equal(actual[:], expected[:]) + + +@pytest.mark.parametrize("test_cls", [Array, AsyncArray]) +def test_nchunks_initialized(test_cls: type[Array] | type[AsyncArray]) -> None: + """ + Test that nchunks_initialized accurately returns the number of stored chunks. 
+ """ + store = MemoryStore({}, mode="w") + arr = Array.create(store, shape=(100,), chunks=(10,), dtype="i4") + + # write chunks one at a time + for idx, region in enumerate(arr._iter_chunk_regions()): + arr[region] = 1 + expected = idx + 1 + if test_cls == Array: + observed = arr.nchunks_initialized + else: + observed = arr._async_array.nchunks_initialized + assert observed == expected + + # delete chunks + for idx, key in enumerate(arr._iter_chunk_keys()): + sync(arr.store_path.store.delete(key)) + if test_cls == Array: + observed = arr.nchunks_initialized + else: + observed = arr._async_array.nchunks_initialized + expected = arr.nchunks - idx - 1 + assert observed == expected + + +@pytest.mark.parametrize("test_cls", [Array, AsyncArray]) +def test_chunks_initialized(test_cls: type[Array] | type[AsyncArray]) -> None: + """ + Test that chunks_initialized accurately returns the keys of stored chunks. + """ + store = MemoryStore({}, mode="w") + arr = Array.create(store, shape=(100,), chunks=(10,), dtype="i4") + + chunks_accumulated = tuple( + accumulate(tuple(map(lambda v: tuple(v.split(" ")), arr._iter_chunk_keys()))) + ) + for keys, region in zip(chunks_accumulated, arr._iter_chunk_regions(), strict=False): + arr[region] = 1 + + if test_cls == Array: + observed = sorted(chunks_initialized(arr)) + else: + observed = sorted(chunks_initialized(arr._async_array)) + + expected = sorted(keys) + assert observed == expected diff --git a/tests/v3/test_indexing.py b/tests/v3/test_indexing.py index 8b509f93d1..90d34f16b0 100644 --- a/tests/v3/test_indexing.py +++ b/tests/v3/test_indexing.py @@ -1,5 +1,6 @@ from __future__ import annotations +import itertools from collections import Counter from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -16,6 +17,7 @@ CoordinateSelection, OrthogonalSelection, Selection, + _iter_grid, make_slice_selection, normalize_integer_selection, oindex, @@ -1861,3 +1863,53 @@ def test_orthogonal_bool_indexing_like_numpy_ix( # note: in 
python 3.10 z[*selection] is not valid unpacking syntax actual = z[(*selection,)] assert_array_equal(expected, actual, err_msg=f"{selection=}") + + +@pytest.mark.parametrize("ndim", [1, 2, 3]) +@pytest.mark.parametrize("origin_0d", [None, (0,), (1,)]) +@pytest.mark.parametrize("selection_shape_0d", [None, (2,), (3,)]) +def test_iter_grid( + ndim: int, origin_0d: tuple[int] | None, selection_shape_0d: tuple[int] | None +) -> None: + """ + Test that iter_grid works as expected for 1, 2, and 3 dimensions. + """ + grid_shape = (5,) * ndim + + if origin_0d is not None: + origin_kwarg = origin_0d * ndim + origin = origin_kwarg + else: + origin_kwarg = None + origin = (0,) * ndim + + if selection_shape_0d is not None: + selection_shape_kwarg = selection_shape_0d * ndim + selection_shape = selection_shape_kwarg + else: + selection_shape_kwarg = None + selection_shape = tuple(gs - o for gs, o in zip(grid_shape, origin, strict=False)) + + observed = tuple( + _iter_grid(grid_shape, origin=origin_kwarg, selection_shape=selection_shape_kwarg) + ) + + # generate a numpy array of indices, and index it + coord_array = np.array(list(itertools.product(*[range(s) for s in grid_shape]))).reshape( + (*grid_shape, ndim) + ) + coord_array_indexed = coord_array[ + tuple(slice(o, o + s, 1) for o, s in zip(origin, selection_shape, strict=False)) + + (range(ndim),) + ] + + expected = tuple(map(tuple, coord_array_indexed.reshape(-1, ndim).tolist())) + assert observed == expected + + +def test_iter_grid_invalid() -> None: + """ + Ensure that a selection_shape that exceeds the grid_shape + origin produces an indexing error. 
+ """ + with pytest.raises(IndexError): + list(_iter_grid((5,), origin=(0,), selection_shape=(10,))) From e60cbe0b3217165530e04b060a5a16ff968d8744 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 24 Sep 2024 15:09:38 +0200 Subject: [PATCH 19/22] add nchunks test --- tests/v3/test_array.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index d0f32a4470..37e47e7b8b 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -5,10 +5,11 @@ import numpy as np import pytest -from src.zarr.core.array import chunks_initialized from zarr import Array, AsyncArray, Group +from zarr.core.array import chunks_initialized from zarr.core.buffer.cpu import NDBuffer from zarr.core.common import ZarrFormat +from zarr.core.indexing import ceildiv from zarr.core.sync import sync from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.store import LocalStore, MemoryStore @@ -237,6 +238,23 @@ def test_serializable_sync_array(store: LocalStore, zarr_format: ZarrFormat) -> np.testing.assert_array_equal(actual[:], expected[:]) +@pytest.mark.parametrize("test_cls", [Array, AsyncArray]) +@pytest.mark.parametrize("nchunks", (2, 5, 10)) +def test_nchunks(test_cls: type[Array] | type[AsyncArray], nchunks: int) -> None: + """ + Test that nchunks returns the number of chunks defined for the array. 
+ """ + store = MemoryStore({}, mode="w") + shape = 100 + arr = Array.create(store, shape=(shape,), chunks=(ceildiv(shape, nchunks),), dtype="i4") + expected = nchunks + if test_cls == Array: + observed = arr.nchunks + else: + observed = arr._async_array.nchunks + assert observed == expected + + @pytest.mark.parametrize("test_cls", [Array, AsyncArray]) def test_nchunks_initialized(test_cls: type[Array] | type[AsyncArray]) -> None: """ From 5c54449442fafa52f13a35596e8b94c56344105a Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 24 Sep 2024 15:17:22 +0200 Subject: [PATCH 20/22] fix docstrings --- src/zarr/abc/store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 2890e93aa0..805df63e09 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -108,7 +108,7 @@ async def get( Parameters ---------- key : str - byte_range : tuple[int, Optional[int]], optional + byte_range : tuple[int | None, int | None], optional Returns ------- @@ -126,7 +126,7 @@ async def get_partial_values( Parameters ---------- - key_ranges : list[tuple[str, tuple[int | None, int | None]]] + key_ranges : Iterable[tuple[str, tuple[int | None, int | None]]] Ordered set of key, range pairs, a key may occur multiple times with different ranges Returns From e8598c6504dd14d70a94e28dd5c8b61495b4d60f Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 24 Sep 2024 15:25:20 +0200 Subject: [PATCH 21/22] fix docstring --- src/zarr/core/indexing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 0c5fbd52ea..50af5b30ec 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -101,10 +101,10 @@ def _iter_grid( ) -> Iterator[ChunkCoords]: """ Iterate over the elements of grid of integers, with the option to restrict the domain of - iteration to those from a contiguous subregion of that grid. 
+ iteration to a contiguous subregion of that grid. Parameters - --------- + ---------- grid_shape: Sequence[int] The size of the domain to iterate over. origin: Sequence[int] | None, default=None From 768ab43cb4cfb01f69c9fe38aba943a4cb938e17 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 24 Sep 2024 17:21:24 +0200 Subject: [PATCH 22/22] revert unnecessary changes to project config --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ecc184f1db..63a58ac795 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ 'asciitree', 'numpy>=1.24', 'fasteners', - 'numcodecs>=0.13.0', + 'numcodecs>=0.10.0', 'fsspec>2024', 'crc32c', 'typing_extensions', @@ -273,7 +273,6 @@ filterwarnings = [ "error:::zarr.*", "ignore:PY_SSIZE_T_CLEAN will be required.*:DeprecationWarning", "ignore:The loop argument is deprecated since Python 3.8.*:DeprecationWarning", - "ignore:.*may be removed in an early zarr-python v3 release.:DeprecationWarning", "ignore:Creating a zarr.buffer.gpu.*:UserWarning", "ignore:Duplicate name:UserWarning", # from ZipFile ]