Skip to content

Commit 9f0abdb

Browse files
committed
Zarr Python v3 updates
Remove workaround for zarr-developers/zarr-python#1978
1 parent 69e9f94 commit 9f0abdb

File tree

2 files changed

+23
-36
lines changed

2 files changed

+23
-36
lines changed

cubed/storage/backends/zarr_python_v3.py

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import zarr
44

55
from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store
6-
from cubed.utils import join_path
76

87

98
class ZarrV3ArrayGroup(dict):
@@ -40,41 +39,29 @@ def open_zarr_v3_array(
4039
if isinstance(chunks, int):
4140
chunks = (chunks,)
4241

43-
if mode in ("r", "r+"):
44-
# TODO: remove when https://github.com/zarr-developers/zarr-python/issues/1978 is fixed
45-
if mode == "r+":
46-
mode = "w"
47-
if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
48-
return zarr.open(store=store, mode=mode, path=path)
42+
if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
43+
return zarr.open(
44+
store=store,
45+
mode=mode,
46+
shape=shape,
47+
dtype=dtype,
48+
chunks=chunks,
49+
path=path,
50+
)
51+
52+
group = zarr.open_group(store=store, mode=mode, path=path)
53+
54+
# create/open all the arrays in the group
55+
ret = ZarrV3ArrayGroup(shape=shape, dtype=dtype, chunks=chunks)
56+
for field in dtype.fields:
57+
field_dtype, _ = dtype.fields[field]
58+
if mode in ("r", "r+"):
59+
ret[field] = group[field]
4960
else:
50-
ret = ZarrV3ArrayGroup(shape=shape, dtype=dtype, chunks=chunks)
51-
for field in dtype.fields:
52-
field_dtype, _ = dtype.fields[field]
53-
field_path = field if path is None else join_path(path, field)
54-
ret[field] = zarr.open(store=store, mode=mode, path=field_path)
55-
return ret
56-
else:
57-
overwrite = True if mode == "a" else False
58-
if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
59-
return zarr.create(
61+
ret[field] = group.create_array(
62+
field,
6063
shape=shape,
61-
dtype=dtype,
64+
dtype=field_dtype,
6265
chunk_shape=chunks,
63-
store=store,
64-
overwrite=overwrite,
65-
path=path,
6666
)
67-
else:
68-
ret = ZarrV3ArrayGroup(shape=shape, dtype=dtype, chunks=chunks)
69-
for field in dtype.fields:
70-
field_dtype, _ = dtype.fields[field]
71-
field_path = field if path is None else join_path(path, field)
72-
ret[field] = zarr.create(
73-
shape=shape,
74-
dtype=field_dtype,
75-
chunk_shape=chunks,
76-
store=store,
77-
overwrite=overwrite,
78-
path=field_path,
79-
)
80-
return ret
67+
return ret

cubed/tests/storage/test_zarr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def test_lazy_zarr_array(tmp_path):
88
arr = lazy_zarr_array(zarr_path, shape=(3, 3), dtype=int, chunks=(2, 2))
99

1010
assert not zarr_path.exists()
11-
with pytest.raises((TypeError, ValueError)):
11+
with pytest.raises((FileNotFoundError, TypeError, ValueError)):
1212
arr.open()
1313

1414
arr.create()

0 commit comments

Comments
 (0)