Skip to content

make shardingcodec pickleable #2011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,22 @@ def __init__(
object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec))
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))

# todo: typedict return type
def __getstate__(self) -> dict[str, Any]:
return self.to_dict()

def __setstate__(self, state: dict[str, Any]) -> None:
config = state["configuration"]
object.__setattr__(self, "chunk_shape", parse_shapelike(config["chunk_shape"]))
object.__setattr__(self, "codecs", parse_codecs(config["codecs"]))
object.__setattr__(self, "index_codecs", parse_codecs(config["index_codecs"]))
object.__setattr__(self, "index_location", parse_index_location(config["index_location"]))

# Use instance-local lru_cache to avoid memory leaks
object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec))
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
_, configuration_parsed = parse_named_configuration(data, "sharding_indexed")
Expand Down
35 changes: 18 additions & 17 deletions src/zarr/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,24 +1220,25 @@ def make_slice_selection(selection: Any) -> list[slice]:
return ls


def morton_order_iter(chunk_shape: ChunkCoords) -> Iterator[ChunkCoords]:
def decode_morton(z: int, chunk_shape: ChunkCoords) -> ChunkCoords:
# Inspired by compressed morton code as implemented in Neuroglancer
# https://github.com/google/neuroglancer/blob/master/src/neuroglancer/datasource/precomputed/volume.md#compressed-morton-code
bits = tuple(math.ceil(math.log2(c)) for c in chunk_shape)
max_coords_bits = max(*bits)
input_bit = 0
input_value = z
out = [0 for _ in range(len(chunk_shape))]

for coord_bit in range(max_coords_bits):
for dim in range(len(chunk_shape)):
if coord_bit < bits[dim]:
bit = (input_value >> input_bit) & 1
out[dim] |= bit << coord_bit
input_bit += 1
return tuple(out)
def decode_morton(z: int, chunk_shape: ChunkCoords) -> ChunkCoords:
# Inspired by compressed morton code as implemented in Neuroglancer
# https://github.com/google/neuroglancer/blob/master/src/neuroglancer/datasource/precomputed/volume.md#compressed-morton-code
bits = tuple(math.ceil(math.log2(c)) for c in chunk_shape)
max_coords_bits = max(bits)
input_bit = 0
input_value = z
out = [0] * len(chunk_shape)

for coord_bit in range(max_coords_bits):
for dim in range(len(chunk_shape)):
if coord_bit < bits[dim]:
bit = (input_value >> input_bit) & 1
out[dim] |= bit << coord_bit
input_bit += 1
return tuple(out)


def morton_order_iter(chunk_shape: ChunkCoords) -> Iterator[ChunkCoords]:
for i in range(product(chunk_shape)):
yield decode_morton(i, chunk_shape)

Expand Down
41 changes: 31 additions & 10 deletions tests/v3/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
from types import ModuleType
from typing import TYPE_CHECKING

from zarr.common import ZarrFormat
from _pytest.compat import LEGACY_PATH

from zarr.abc.store import Store
from zarr.common import ChunkCoords, MemoryOrder, ZarrFormat
from zarr.group import AsyncGroup

if TYPE_CHECKING:
from typing import Any, Literal
import pathlib
from dataclasses import dataclass, field

import numpy as np
import pytest

from zarr.store import LocalStore, MemoryStore, StorePath
Expand All @@ -26,40 +30,40 @@ def parse_store(
if store == "memory":
return MemoryStore(mode="w")
if store == "remote":
return RemoteStore(mode="w")
return RemoteStore(url=path, mode="w")
raise AssertionError


@pytest.fixture(params=[str, pathlib.Path])
def path_type(request):
def path_type(request: pytest.FixtureRequest) -> Any:
return request.param


# todo: harmonize this with local_store fixture
@pytest.fixture
def store_path(tmpdir):
def store_path(tmpdir: LEGACY_PATH) -> StorePath:
store = LocalStore(str(tmpdir), mode="w")
p = StorePath(store)
return p


@pytest.fixture(scope="function")
def local_store(tmpdir):
def local_store(tmpdir: LEGACY_PATH) -> LocalStore:
return LocalStore(str(tmpdir), mode="w")


@pytest.fixture(scope="function")
def remote_store():
return RemoteStore(mode="w")
def remote_store(url: str) -> RemoteStore:
return RemoteStore(url, mode="w")


@pytest.fixture(scope="function")
def memory_store():
def memory_store() -> MemoryStore:
return MemoryStore(mode="w")


@pytest.fixture(scope="function")
def store(request: str, tmpdir):
def store(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> Store:
param = request.param
return parse_store(param, str(tmpdir))

Expand All @@ -72,7 +76,7 @@ class AsyncGroupRequest:


@pytest.fixture(scope="function")
async def async_group(request: pytest.FixtureRequest, tmpdir) -> AsyncGroup:
async def async_group(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> AsyncGroup:
param: AsyncGroupRequest = request.param

store = parse_store(param.store, str(tmpdir))
Expand All @@ -90,3 +94,20 @@ def xp(request: pytest.FixtureRequest) -> Iterator[ModuleType]:
"""Fixture to parametrize over numpy-like libraries"""

yield pytest.importorskip(request.param)


@dataclass
class ArrayRequest:
shape: ChunkCoords
dtype: str
order: MemoryOrder


@pytest.fixture
def array_fixture(request: pytest.FixtureRequest) -> np.ndarray:
array_request: ArrayRequest = request.param
return (
np.arange(np.prod(array_request.shape))
.reshape(array_request.shape, order=array_request.order)
.astype(array_request.dtype)
)
Loading