Skip to content

Generalize stateful store test #2202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def v2_dtypes() -> st.SearchStrategy[np.dtype]:
)
array_names = node_names
attrs = st.none() | st.dictionaries(_attr_keys, _attr_values)
paths = st.lists(node_names, min_size=1).map("/".join) | st.just("/")
keys = st.lists(node_names, min_size=1).map(lambda x: "/".join(x))
paths = st.just("/") | keys
stores = st.builds(MemoryStore, st.just({}), mode=st.just("w"))
compressors = st.sampled_from([None, "default"])
zarr_formats: st.SearchStrategy[Literal[2, 3]] = st.sampled_from([2, 3])
Expand Down Expand Up @@ -171,7 +172,9 @@ def basic_indices(draw: st.DrawFn, *, shape: tuple[int], **kwargs): # type: ign
)


def key_ranges(keys: SearchStrategy = node_names) -> SearchStrategy[list]:
def key_ranges(
keys: SearchStrategy = node_names, max_size: int | None = None
) -> SearchStrategy[list[int]]:
"""
Function to generate key_ranges strategy for get_partial_values()
returns list strategy w/ form::
Expand All @@ -180,7 +183,8 @@ def key_ranges(keys: SearchStrategy = node_names) -> SearchStrategy[list]:
(key, (range_start, range_step)),...]
"""
byte_ranges = st.tuples(
st.none() | st.integers(min_value=0), st.none() | st.integers(min_value=0)
st.none() | st.integers(min_value=0, max_value=max_size),
st.none() | st.integers(min_value=0, max_value=max_size),
)
key_tuple = st.tuples(keys, byte_ranges)
return st.lists(key_tuple, min_size=1, max_size=10)
11 changes: 10 additions & 1 deletion tests/v3/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from hypothesis import HealthCheck, Verbosity, settings

from zarr import AsyncGroup, config
from zarr.abc.store import Store
from zarr.core.sync import sync
from zarr.store import LocalStore, MemoryStore, StorePath, ZipStore
from zarr.store.remote import RemoteStore

Expand All @@ -19,7 +21,6 @@

from _pytest.compat import LEGACY_PATH

from zarr.abc.store import Store
from zarr.core.common import ChunkCoords, MemoryOrder, ZarrFormat


Expand Down Expand Up @@ -75,6 +76,14 @@ async def store(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> Store:
return await parse_store(param, str(tmpdir))


@pytest.fixture(params=["local", "memory", "zip"])
def sync_store(request: pytest.FixtureRequest, tmp_path: LEGACY_PATH) -> Store:
    """Return a store built synchronously for each parametrized backend.

    The fixture is parametrized over ``"local"``, ``"memory"`` and ``"zip"``.
    ``parse_store`` is run to completion through ``sync`` so that non-async
    tests can receive a ready-to-use store.

    Raises
    ------
    TypeError
        If ``parse_store`` produced something that is not a ``Store``.
    """
    # NOTE(review): ``tmp_path`` is annotated as LEGACY_PATH; pytest's ``tmp_path``
    # fixture normally yields a ``pathlib.Path`` -- confirm the annotation.
    result = sync(parse_store(request.param, str(tmp_path)))
    if not isinstance(result, Store):
        # Use an f-string: ``str + <non-str object>`` would itself raise an
        # unrelated "can only concatenate str" TypeError and hide this message.
        raise TypeError(f"Wrong store class returned by test fixture! got {result!r} instead")
    return result


@dataclass
class AsyncGroupRequest:
zarr_format: ZarrFormat
Expand Down
63 changes: 43 additions & 20 deletions tests/v3/test_store/test_stateful_store.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
# Stateful tests for arbitrary Zarr stores.


import hypothesis.strategies as st
import pytest
from hypothesis import assume, note
from hypothesis.stateful import (
RuleBasedStateMachine,
Settings,
initialize,
invariant,
precondition,
rule,
run_state_machine_as_test,
)
from hypothesis.strategies import DataObject

import zarr
from zarr.abc.store import AccessMode, Store
from zarr.core.buffer import BufferPrototype, cpu, default_buffer_prototype
from zarr.store import MemoryStore
from zarr.testing.strategies import key_ranges, paths
from zarr.store import LocalStore, ZipStore
from zarr.testing.strategies import key_ranges
from zarr.testing.strategies import keys as zarr_keys

MAX_BINARY_SIZE = 100


class SyncStoreWrapper(zarr.core.sync.SyncMixin):
Expand Down Expand Up @@ -99,13 +104,17 @@ class ZarrStoreStateMachine(RuleBasedStateMachine):
https://hypothesis.readthedocs.io/en/latest/stateful.html
"""

def __init__(self) -> None:
def __init__(self, store: Store) -> None:
super().__init__()
self.model: dict[str, bytes] = {}
self.store = SyncStoreWrapper(MemoryStore(mode="w"))
self.store = SyncStoreWrapper(store)
self.prototype = default_buffer_prototype()

@rule(key=paths, data=st.binary(min_size=0, max_size=100))
@initialize()
def init_store(self) -> None:
    """Clear the wrapped store so it starts out matching the empty model.

    The store is supplied by the caller of ``__init__`` while ``self.model``
    starts as an empty dict, so any content already in the store is removed
    here.
    """
    self.store.clear()

@rule(key=zarr_keys, data=st.binary(min_size=0, max_size=MAX_BINARY_SIZE))
def set(self, key: str, data: DataObject) -> None:
note(f"(set) Setting {key!r} with {data}")
assert not self.store.mode.readonly
Expand All @@ -114,7 +123,7 @@ def set(self, key: str, data: DataObject) -> None:
self.model[key] = data_buf

@precondition(lambda self: len(self.model.keys()) > 0)
@rule(key=paths, data=st.data())
@rule(key=zarr_keys, data=st.data())
def get(self, key: str, data: DataObject) -> None:
key = data.draw(
st.sampled_from(sorted(self.model.keys()))
Expand All @@ -124,16 +133,18 @@ def get(self, key: str, data: DataObject) -> None:
# converting to bytes is necessary here because set() stores data_buf in the model
assert self.model[key].to_bytes() == (store_value.to_bytes())

@rule(key=paths, data=st.data())
def get_invalid_keys(self, key: str, data: DataObject) -> None:
@rule(key=zarr_keys, data=st.data())
def get_invalid_zarr_keys(self, key: str, data: DataObject) -> None:
note("(get_invalid)")
assume(key not in self.model)
assert self.store.get(key, self.prototype) is None

@precondition(lambda self: len(self.model.keys()) > 0)
@rule(data=st.data())
def get_partial_values(self, data: DataObject) -> None:
key_range = data.draw(key_ranges(keys=st.sampled_from(sorted(self.model.keys()))))
key_range = data.draw(
key_ranges(keys=st.sampled_from(sorted(self.model.keys())), max_size=MAX_BINARY_SIZE)
)
note(f"(get partial) {key_range=}")
obs_maybe = self.store.get_partial_values(key_range, self.prototype)
observed = []
Expand Down Expand Up @@ -173,16 +184,20 @@ def clear(self) -> None:
self.store.clear()
self.model.clear()

assert self.store.empty()

assert len(self.model.keys()) == len(list(self.store.list())) == 0

@rule()
# Local store can be non-empty when there are subdirectories but no files
@precondition(lambda self: not isinstance(self.store.store, LocalStore))
def empty(self) -> None:
    """Check that the store and the model agree on whether they are empty."""
    note("(empty)")

    # Both sides must report the same state: either both empty or both populated.
    store_is_empty = self.store.empty()
    model_is_empty = not self.model
    assert store_is_empty == model_is_empty

@rule(key=paths)
@rule(key=zarr_keys)
def exists(self, key: str) -> None:
note("(exists)")

Expand All @@ -191,9 +206,9 @@ def exists(self, key: str) -> None:
@invariant()
def check_paths_equal(self) -> None:
note("Checking that paths are equal")
paths = list(self.store.list())
paths = sorted(self.store.list())

assert list(self.model.keys()) == paths
assert sorted(self.model.keys()) == paths

@invariant()
def check_vals_equal(self) -> None:
Expand All @@ -203,24 +218,32 @@ def check_vals_equal(self) -> None:
assert val.to_bytes() == store_item

@invariant()
def check_num_keys_equal(self) -> None:
note("check num keys equal")
def check_num_zarr_keys_equal(self) -> None:
note("check num zarr_keys equal")

assert len(self.model) == len(list(self.store.list()))

@invariant()
def check_keys(self) -> None:
def check_zarr_keys(self) -> None:
keys = list(self.store.list())

if len(keys) == 0:
if not keys:
assert self.store.empty() is True

elif len(keys) != 0:
else:
assert self.store.empty() is False

for key in keys:
assert self.store.exists(key) is True
note("checking keys / exists / empty")


StatefulStoreTest = ZarrStoreStateMachine.TestCase
def test_zarr_hierarchy(sync_store: Store) -> None:
    """Run the ``ZarrStoreStateMachine`` stateful test against ``sync_store``.

    ``ZipStore`` and ``LocalStore`` instances are skipped with the reasons
    given below; every other store is driven through the state machine.
    """
    # Guard clauses first: bail out before building anything for skipped stores.
    if isinstance(sync_store, ZipStore):
        pytest.skip(reason="ZipStore does not support delete")
    if isinstance(sync_store, LocalStore):
        pytest.skip(reason="This test has errors")

    def mk_test_instance_sync():
        return ZarrStoreStateMachine(sync_store)

    machine_settings = Settings(report_multiple_bugs=True)
    run_state_machine_as_test(mk_test_instance_sync, settings=machine_settings)