Skip to content

Special case str dtype in array creation #2323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
4 changes: 3 additions & 1 deletion src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ShapeLike,
ZarrFormat,
concurrent_map,
parse_dtype,
parse_shapelike,
product,
)
Expand Down Expand Up @@ -226,7 +227,8 @@ async def create(
if chunks is not None and chunk_shape is not None:
raise ValueError("Only one of chunk_shape or chunks can be provided.")

dtype = np.dtype(dtype)
dtype = parse_dtype(dtype, zarr_format)
# dtype = np.dtype(dtype)
if chunks:
_chunks = normalize_chunks(chunks, shape, dtype.itemsize)
else:
Expand Down
14 changes: 14 additions & 0 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
overload,
)

import numpy as np

from zarr.core.strings import _STRING_DTYPE

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterator

Expand Down Expand Up @@ -162,3 +166,13 @@ def parse_order(data: Any) -> Literal["C", "F"]:
if data in ("C", "F"):
return cast(Literal["C", "F"], data)
raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.")


def parse_dtype(dtype: Any, zarr_format: ZarrFormat) -> np.dtype[Any]:
    """
    Resolve a dtype specification to a NumPy dtype, special-casing ``str``.

    The builtin ``str`` type and the literal string ``"str"`` are handled
    separately from every other input: for Zarr format 2 they resolve to the
    NumPy ``object`` dtype, and for any other format they resolve to
    ``_STRING_DTYPE``. All other values are handed to ``np.dtype`` unchanged.

    Parameters
    ----------
    dtype : Any
        The dtype specification to resolve.
    zarr_format : ZarrFormat
        The Zarr format the dtype is being resolved for.

    Returns
    -------
    np.dtype[Any]
        The resolved dtype.
    """
    requests_string = dtype is str or dtype == "str"
    if not requests_string:
        return np.dtype(dtype)
    # For format 2 a string request is special-cased as the object dtype.
    return np.dtype("object") if zarr_format == 2 else _STRING_DTYPE
11 changes: 3 additions & 8 deletions src/zarr/core/metadata/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from zarr.core.array_spec import ArraySpec
from zarr.core.chunk_grids import RegularChunkGrid
from zarr.core.chunk_key_encodings import parse_separator
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike
from zarr.core.config import config, parse_indexing_order
from zarr.core.metadata.common import ArrayMetadata, parse_attributes

Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
Metadata for a Zarr version 2 array.
"""
shape_parsed = parse_shapelike(shape)
data_type_parsed = parse_dtype(dtype)
data_type_parsed = parse_dtype(dtype, zarr_format=2)
chunks_parsed = parse_shapelike(chunks)
compressor_parsed = parse_compressor(compressor)
order_parsed = parse_indexing_order(order)
Expand Down Expand Up @@ -141,7 +141,7 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
_data = data.copy()
# check that the zarr_format attribute is correct
_ = parse_zarr_format(_data.pop("zarr_format"))
dtype = parse_dtype(_data["dtype"])
dtype = parse_dtype(_data["dtype"], zarr_format=2)

if dtype.kind in "SV":
fill_value_encoded = _data.get("fill_value")
Expand Down Expand Up @@ -201,11 +201,6 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
return replace(self, attributes=attributes)


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
    """
    Convert a dtype-like value into a NumPy dtype.

    The value is passed straight to ``np.dtype``; no validation beyond what
    that constructor performs is applied here.
    """
    # todo: real validation
    parsed: np.dtype[Any] = np.dtype(data)
    return parsed


def parse_zarr_format(data: object) -> Literal[2]:
if data == 2:
return 2
Expand Down
2 changes: 1 addition & 1 deletion tests/v3/test_codecs/test_vlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING
from zarr.storage.common import StorePath

numpy_str_dtypes: list[type | None] = [None, str, np.dtypes.StrDType]
numpy_str_dtypes: list[type | str | None] = [None, str, "str", np.dtypes.StrDType]
expected_zarr_string_dtype: np.dtype[Any]
if _NUMPY_SUPPORTS_VLEN_STRING:
numpy_str_dtypes.append(np.dtypes.StringDType)
Expand Down
8 changes: 8 additions & 0 deletions tests/v3/test_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from collections.abc import Iterator
from typing import Any

import numpy as np
import pytest
Expand Down Expand Up @@ -84,3 +85,10 @@ async def test_v2_encode_decode(dtype):
data = zarr.open_array(store=store, path="foo")[:]
expected = np.full((3,), b"X", dtype=dtype)
np.testing.assert_equal(data, expected)


@pytest.mark.parametrize("dtype", [str, "str"])
async def test_create_dtype_str(dtype: Any) -> None:
    """Creating a format-2 array with a ``str`` dtype request yields an object dtype."""
    created = zarr.create(shape=10, dtype=dtype, zarr_format=2)
    # The in-memory array reports the object dtype kind.
    assert created.dtype.kind == "O"
    # The serialized metadata records the same dtype as "|O".
    serialized = created.metadata.to_dict()
    assert serialized["dtype"] == "|O"