Skip to content

Further validation of v3 fill values #2216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ filterwarnings = [
"ignore:PY_SSIZE_T_CLEAN will be required.*:DeprecationWarning",
"ignore:The loop argument is deprecated since Python 3.8.*:DeprecationWarning",
"ignore:Creating a zarr.buffer.gpu.*:UserWarning",
"ignore:Duplicate name:UserWarning", # from ZipFile
]
markers = [
"gpu: mark a test as requiring CuPy and GPU"
Expand Down
5 changes: 2 additions & 3 deletions src/zarr/core/array_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

from zarr.core.common import parse_dtype, parse_fill_value, parse_order, parse_shapelike
from zarr.core.common import parse_fill_value, parse_order, parse_shapelike

if TYPE_CHECKING:
import numpy as np
Expand All @@ -29,12 +29,11 @@ def __init__(
prototype: BufferPrototype,
) -> None:
shape_parsed = parse_shapelike(shape)
dtype_parsed = parse_dtype(dtype)
fill_value_parsed = parse_fill_value(fill_value)
order_parsed = parse_order(order)

object.__setattr__(self, "shape", shape_parsed)
object.__setattr__(self, "dtype", dtype_parsed)
object.__setattr__(self, "dtype", dtype)
object.__setattr__(self, "fill_value", fill_value_parsed)
object.__setattr__(self, "order", order_parsed)
object.__setattr__(self, "prototype", prototype)
Expand Down
7 changes: 0 additions & 7 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterator

import numpy as np
import numpy.typing as npt

ZARR_JSON = "zarr.json"
ZARRAY_JSON = ".zarray"
Expand Down Expand Up @@ -155,11 +153,6 @@ def parse_shapelike(data: int | Iterable[int]) -> tuple[int, ...]:
return data_tuple


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
    """Convert a dtype-like object into a ``numpy.dtype``.

    Parameters
    ----------
    data
        Anything the ``numpy.dtype`` constructor accepts.

    Returns
    -------
    numpy.dtype
        The result of ``np.dtype(data)``.

    Notes
    -----
    No validation is performed here; any exception raised by NumPy
    propagates to the caller unchanged.
    """
    # todo: real validation
    return np.dtype(data)


def parse_fill_value(data: Any) -> Any:
    """Return ``data`` unchanged.

    This is a placeholder hook: no validation or conversion is applied to
    the fill value here.
    """
    # todo: real validation
    return data
Expand Down
29 changes: 25 additions & 4 deletions src/zarr/core/metadata/v2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import Iterable
from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -21,7 +22,7 @@
from zarr.core.array_spec import ArraySpec
from zarr.core.chunk_grids import RegularChunkGrid
from zarr.core.chunk_key_encodings import parse_separator
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike
from zarr.core.config import config, parse_indexing_order
from zarr.core.metadata.common import ArrayMetadata, parse_attributes

Expand Down Expand Up @@ -100,9 +101,24 @@ def _json_convert(
else:
return o.descr
if np.isscalar(o):
# convert numpy scalar to python type, and pass
# python types through
return getattr(o, "item", lambda: o)()
out: Any
if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"):
# https://github.com/zarr-developers/zarr-python/issues/2119
# `.item()` on a datetime type might or might not return an
# integer, depending on the value.
# Explicitly cast to an int first, and then grab .item()
out = o.view("i8").item()
else:
# convert numpy scalar to python type, and pass
# python types through
out = getattr(o, "item", lambda: o)()
if isinstance(out, complex):
# python complex types are not JSON serializable, so we use the
# serialization defined in the zarr v3 spec
return [out.real, out.imag]
return out
if isinstance(o, Enum):
return o.name
raise TypeError

zarray_dict = self.to_dict()
Expand Down Expand Up @@ -157,6 +173,11 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
return replace(self, attributes=attributes)


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
    """Convert a dtype-like object into a ``numpy.dtype``.

    The value is handed directly to the ``numpy.dtype`` constructor; no
    additional checks are applied, so any exception raised by NumPy
    propagates to the caller unchanged.
    """
    # todo: real validation
    return np.dtype(data)


def parse_zarr_format(data: object) -> Literal[2]:
if data == 2:
return 2
Expand Down
44 changes: 42 additions & 2 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from zarr.core.buffer import default_buffer_prototype
from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid
from zarr.core.chunk_key_encodings import ChunkKeyEncoding
from zarr.core.common import ZARR_JSON, parse_dtype, parse_named_configuration, parse_shapelike
from zarr.core.common import ZARR_JSON, parse_named_configuration, parse_shapelike
from zarr.core.config import config
from zarr.core.metadata.common import ArrayMetadata, parse_attributes
from zarr.registry import get_codec_class
Expand Down Expand Up @@ -215,6 +215,10 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
# check that the node_type attribute is correct
_ = parse_node_type_array(_data.pop("node_type"))

# check that the data_type attribute is valid
if _data["data_type"] not in DataType:
raise ValueError(f"Invalid V3 data_type: {_data['data_type']}")

# dimension_names key is optional, normalize missing to `None`
_data["dimension_names"] = _data.pop("dimension_names", None)
# attributes key is optional, normalize missing to `None`
Expand Down Expand Up @@ -328,7 +332,17 @@ def parse_fill_value(
raise ValueError(msg)
msg = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}."
raise TypeError(msg)
return dtype.type(fill_value) # type: ignore[arg-type]

# Cast the fill_value to the given dtype
try:
casted_value = np.dtype(dtype).type(fill_value)
except (ValueError, OverflowError, TypeError) as e:
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") from e
# Check if the value is still representable by the dtype
if fill_value != casted_value:
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}")

return casted_value


# For type checking
Expand All @@ -345,8 +359,11 @@ class DataType(Enum):
uint16 = "uint16"
uint32 = "uint32"
uint64 = "uint64"
float16 = "float16"
float32 = "float32"
float64 = "float64"
complex64 = "complex64"
complex128 = "complex128"

@property
def byte_count(self) -> int:
Expand All @@ -360,8 +377,11 @@ def byte_count(self) -> int:
DataType.uint16: 2,
DataType.uint32: 4,
DataType.uint64: 8,
DataType.float16: 2,
DataType.float32: 4,
DataType.float64: 8,
DataType.complex64: 8,
DataType.complex128: 16,
}
return data_type_byte_counts[self]

Expand All @@ -381,8 +401,11 @@ def to_numpy_shortname(self) -> str:
DataType.uint16: "u2",
DataType.uint32: "u4",
DataType.uint64: "u8",
DataType.float16: "f2",
DataType.float32: "f4",
DataType.float64: "f8",
DataType.complex64: "c8",
DataType.complex128: "c16",
}
return data_type_to_numpy[self]

Expand All @@ -399,7 +422,24 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
"<u2": "uint16",
"<u4": "uint32",
"<u8": "uint64",
"<f2": "float16",
"<f4": "float32",
"<f8": "float64",
"<c8": "complex64",
"<c16": "complex128",
}
return DataType[dtype_to_data_type[dtype.str]]


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
    """Convert ``data`` to a ``numpy.dtype`` and check it maps to a ``DataType``.

    Parameters
    ----------
    data
        A dtype-like object passed to the ``numpy.dtype`` constructor.

    Returns
    -------
    numpy.dtype
        The constructed dtype, returned only if ``DataType.from_dtype``
        accepts it.

    Raises
    ------
    ValueError
        If ``np.dtype(data)`` raises ``TypeError``, or if
        ``DataType.from_dtype`` raises ``KeyError`` for the resulting dtype.
    """
    try:
        dtype = np.dtype(data)
    except TypeError as e:
        # Report the raw input: no dtype object exists at this point.
        raise ValueError(f"Invalid V3 data_type: {data}") from e

    # check that this is a valid v3 data_type; the lookup result itself is
    # not needed, only whether the lookup succeeds
    try:
        DataType.from_dtype(dtype)
    except KeyError as e:
        raise ValueError(f"Invalid V3 data_type: {dtype}") from e

    return dtype
9 changes: 7 additions & 2 deletions src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@
paths = st.lists(node_names, min_size=1).map(lambda x: "/".join(x)) | st.just("/")
np_arrays = npst.arrays(
# TODO: re-enable timedeltas once they are supported
dtype=npst.scalar_dtypes().filter(lambda x: x.kind != "m"),
dtype=npst.scalar_dtypes().filter(
lambda x: (x.kind not in ["m", "M"]) and (x.byteorder not in [">"])
),
shape=npst.array_shapes(max_dims=4),
)
stores = st.builds(MemoryStore, st.just({}), mode=st.just("w"))
compressors = st.sampled_from([None, "default"])
format = st.sampled_from([2, 3])


@st.composite # type: ignore[misc]
Expand Down Expand Up @@ -69,12 +72,14 @@ def arrays(
paths: st.SearchStrategy[None | str] = paths,
array_names: st.SearchStrategy = array_names,
attrs: st.SearchStrategy = attrs,
format: st.SearchStrategy = format,
) -> Array:
store = draw(stores)
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
path = draw(paths)
name = draw(array_names)
attributes = draw(attrs)
zarr_format = draw(format)
# compressor = draw(compressors)

# TODO: clean this up
Expand All @@ -99,7 +104,7 @@ def arrays(
expected_attrs = {} if attributes is None else attributes

array_path = path + ("/" if not path.endswith("/") else "") + name
root = Group.create(store)
root = Group.create(store, zarr_format=zarr_format)
fill_value_args: tuple[Any, ...] = tuple()
if nparray.dtype.kind == "M":
m = re.search(r"\[(.+)\]", nparray.dtype.str)
Expand Down
68 changes: 55 additions & 13 deletions tests/v3/test_metadata/test_v3.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

import json
import re
from typing import TYPE_CHECKING, Literal

from zarr.codecs.bytes import BytesCodec
from zarr.core.buffer import default_buffer_prototype
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
from zarr.core.metadata.v3 import ArrayV3Metadata

Expand All @@ -19,7 +17,12 @@
import numpy as np
import pytest

from zarr.core.metadata.v3 import parse_dimension_names, parse_fill_value, parse_zarr_format
from zarr.core.metadata.v3 import (
parse_dimension_names,
parse_dtype,
parse_fill_value,
parse_zarr_format,
)

bool_dtypes = ("bool",)

Expand Down Expand Up @@ -234,22 +237,61 @@ def test_metadata_to_dict(
assert observed == expected


@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
@pytest.mark.parametrize("precision", ["ns", "D"])
async def test_datetime_metadata(fill_value: int, precision: str) -> None:
# @pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
# @pytest.mark.parametrize("precision", ["ns", "D"])
# async def test_datetime_metadata(fill_value: int, precision: str) -> None:
# metadata_dict = {
# "zarr_format": 3,
# "node_type": "array",
# "shape": (1,),
# "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
# "data_type": f"<M8[{precision}]",
# "chunk_key_encoding": {"name": "default", "separator": "."},
# "codecs": (),
# "fill_value": np.datetime64(fill_value, precision),
# }
# metadata = ArrayV3Metadata.from_dict(metadata_dict)
# # ensure there isn't a TypeError here.
# d = metadata.to_buffer_dict(default_buffer_prototype())

# result = json.loads(d["zarr.json"].to_bytes())
# assert result["fill_value"] == fill_value


async def test_invalid_dtype_raises() -> None:
    """``ArrayV3Metadata.from_dict`` must reject a datetime64 ``data_type``."""
    metadata_dict = {
        "zarr_format": 3,
        "node_type": "array",
        "shape": (1,),
        "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
        "data_type": "<M8[ns]",
        "chunk_key_encoding": {"name": "default", "separator": "."},
        "codecs": (),
        "fill_value": np.datetime64(0, "ns"),
    }
    # The call must sit inside the context manager so the expected
    # ValueError is captured rather than aborting the test.
    with pytest.raises(ValueError, match=r"Invalid V3 data_type"):
        ArrayV3Metadata.from_dict(metadata_dict)


@pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
def test_parse_invalid_dtype_raises(data: object) -> None:
    """``parse_dtype`` must raise ``ValueError`` for each parametrized input."""
    with pytest.raises(ValueError, match=r"Invalid V3 data_type"):
        parse_dtype(data)

result = json.loads(d["zarr.json"].to_bytes())
assert result["fill_value"] == fill_value

@pytest.mark.parametrize(
    "data_type,fill_value", [("uint8", -1), ("int32", 22.5), ("float32", "foo")]
)
async def test_invalid_fill_value_raises(data_type: str, fill_value: int | float | str) -> None:
    """``ArrayV3Metadata.from_dict`` must reject each parametrized fill value.

    Each ``(data_type, fill_value)`` pair is expected to make ``from_dict``
    raise ``ValueError`` with a message naming the dtype.
    """
    metadata_dict = {
        "zarr_format": 3,
        "node_type": "array",
        "shape": (1,),
        "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
        "data_type": data_type,
        "chunk_key_encoding": {"name": "default", "separator": "."},
        "codecs": (),
        "fill_value": fill_value,  # expected to be rejected for `data_type`
    }
    with pytest.raises(ValueError, match=rf"fill value .* is not valid for dtype {data_type}"):
        ArrayV3Metadata.from_dict(metadata_dict)
Loading