Skip to content

FSStore: use ensure_bytes() #1285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions zarr/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@
from zarr.util import (buffer_size, json_loads, nolock, normalize_chunks,
normalize_dimension_separator,
normalize_dtype, normalize_fill_value, normalize_order,
normalize_shape, normalize_storage_path, retry_call
)
normalize_shape, normalize_storage_path, retry_call,
ensure_contiguous_ndarray_or_bytes)

from zarr._storage.absstore import ABSStore # noqa: F401
from zarr._storage.store import (_get_hierarchy_metadata, # noqa: F401
Expand Down Expand Up @@ -1395,13 +1395,19 @@ def __getitem__(self, key):
def setitems(self, values):
    """Store several key/value pairs in the underlying mapping in one call.

    Parameters
    ----------
    values : mapping
        Maps store keys to the buffers to write. Only ``values.items()``
        is used.

    Raises
    ------
    ReadOnlyError
        If the store was opened with ``mode == 'r'``.
    """
    if self.mode == 'r':
        raise ReadOnlyError()

    # Normalize keys and make sure the values are bytes (or a contiguous
    # ndarray-like), in a single pass over the input.
    # NOTE(review): a maintainer suggested delaying the instantiation of
    # this dict so the copies occur as the values are requested; that was
    # left as possible follow-on work.
    values = {
        self._normalize_key(key): ensure_contiguous_ndarray_or_bytes(val)
        for key, val in values.items()
    }
    self.map.setitems(values)

def __setitem__(self, key, value):
if self.mode == 'r':
raise ReadOnlyError()
key = self._normalize_key(key)
value = ensure_contiguous_ndarray_or_bytes(value)
path = self.dir_path(key)
try:
if self.fs.isdir(path):
Expand Down
25 changes: 25 additions & 0 deletions zarr/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from numpy.testing import assert_array_almost_equal, assert_array_equal
from pkg_resources import parse_version

import zarr
from zarr._storage.store import (
v3_api_available,
)
Expand Down Expand Up @@ -3409,3 +3410,27 @@ def test_array_mismatched_store_versions():
Array(store_v3, path='dataset', read_only=False, chunk_store=chunk_store_v2)
with pytest.raises(ValueError):
Array(store_v2, path='dataset', read_only=False, chunk_store=chunk_store_v3)


@pytest.mark.skipif(have_fsspec is False, reason="needs fsspec")
def test_issue_1279(tmpdir):
    """Round-trip an uncompressed, F-ordered array through an FSStore.

    Regression test for
    <https://github.com/zarr-developers/zarr-python/issues/1279>.
    """
    root = str(tmpdir)
    expected = np.arange(25).reshape((5, 5))

    # Write through a store opened in append mode, with no compressor and
    # Fortran memory order.
    writable_store = FSStore(url=root, mode="a")
    array = zarr.create(
        shape=expected.shape,
        chunks=(5, 5),
        dtype=expected.dtype,
        compressor=None,
        store=writable_store,
        order="F",
    )
    array[:] = expected

    # Read back through a fresh read-only store on the same location.
    readonly_store = FSStore(url=root, mode="r")
    reopened = zarr.open_array(store=readonly_store)

    assert_array_equal(expected, reopened[:])
36 changes: 33 additions & 3 deletions zarr/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,22 @@
from textwrap import TextWrapper
import mmap
import time
from typing import Any, Callable, Dict, Optional, Tuple, Union

import numpy as np
from asciitree import BoxStyle, LeftAligned
from asciitree.traversal import Traversal
from collections.abc import Iterable
from numcodecs.compat import ensure_text, ensure_ndarray_like
from numcodecs.compat import (
ensure_text,
ensure_ndarray_like,
ensure_bytes,
ensure_contiguous_ndarray_like
)
from numcodecs.ndarray_like import NDArrayLike
from numcodecs.registry import codec_registry
from numcodecs.blosc import cbuffer_sizes, cbuffer_metainfo

from typing import Any, Callable, Dict, Optional, Tuple, Union


def flatten(arg: Iterable) -> Iterable:
for element in arg:
Expand Down Expand Up @@ -696,3 +701,28 @@ def all_equal(value: Any, array: Any):
# using == raises warnings from numpy deprecated pattern, but
# using np.equal() raises type errors for structured dtypes...
return np.all(value == array)


def ensure_contiguous_ndarray_or_bytes(buf) -> Union[NDArrayLike, bytes]:
    """Coerce `buf` to a contiguous ndarray-like object, or else to bytes.

    A zero-copy conversion to a contiguous array is attempted first. If
    that conversion raises ``TypeError``, `buf` is instead copied into a
    newly allocated `bytes` object.

    Parameters
    ----------
    buf : ndarray-like, array-like, or bytes-like
        A numpy array like object such as numpy.ndarray, cupy.ndarray, or
        any object exporting a buffer interface.

    Returns
    -------
    arr : NDArrayLike or bytes
        An ndarray-like or bytes object
    """
    # NOTE(review): a maintainer suggested this helper may belong in
    # numcodecs.compat; moving it was left as follow-on work.
    try:
        contiguous = ensure_contiguous_ndarray_like(buf)
    except TypeError:
        # The zero-copy conversion was not possible; fall back to a copy.
        return ensure_bytes(buf)
    return contiguous