Skip to content

Commit 54060d3

Browse files
[v3] Implement Group methods for empty, full, ones, and zeros (#2210)
* fill in stubs for Group.{empty,zeros,ones,full,empty_like, zeros_like,onest_like,full_like} * add shape to function signature * change type in function signature and add unit tests * precommit * small fixes * add shape check to tests * update function signatures * cast path to a str * update store path --------- Co-authored-by: Joe Hamman <[email protected]>
1 parent 06e3215 commit 54060d3

File tree

3 files changed

+128
-34
lines changed

3 files changed

+128
-34
lines changed

src/zarr/api/asynchronous.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ async def create(
575575
chunks: ChunkCoords | None = None, # TODO: v2 allowed chunks=True
576576
dtype: npt.DTypeLike | None = None,
577577
compressor: dict[str, JSON] | None = None, # TODO: default and type change
578-
fill_value: Any = 0, # TODO: need type
578+
fill_value: Any | None = 0, # TODO: need type
579579
order: MemoryOrder | None = None, # TODO: default change
580580
store: str | StoreLike | None = None,
581581
synchronizer: Any | None = None,
@@ -827,7 +827,7 @@ async def full_like(a: ArrayLike, **kwargs: Any) -> AsyncArray:
827827
"""
828828
like_kwargs = _like_args(a, kwargs)
829829
if isinstance(a, AsyncArray):
830-
kwargs.setdefault("fill_value", a.metadata.fill_value)
830+
like_kwargs.setdefault("fill_value", a.metadata.fill_value)
831831
return await full(**like_kwargs)
832832

833833

src/zarr/core/group.py

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy.typing as npt
1111
from typing_extensions import deprecated
1212

13+
import zarr.api.asynchronous as async_api
1314
from zarr.abc.metadata import Metadata
1415
from zarr.abc.store import set_or_delete
1516
from zarr.core.array import Array, AsyncArray
@@ -704,29 +705,41 @@ async def arrays(self) -> AsyncGenerator[AsyncArray, None]:
704705
async def tree(self, expand: bool = False, level: int | None = None) -> Any:
705706
raise NotImplementedError
706707

707-
async def empty(self, **kwargs: Any) -> AsyncArray:
708-
raise NotImplementedError
708+
async def empty(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> AsyncArray:
709+
return await async_api.empty(shape=shape, store=self.store_path, path=name, **kwargs)
709710

710-
async def zeros(self, **kwargs: Any) -> AsyncArray:
711-
raise NotImplementedError
711+
async def zeros(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> AsyncArray:
712+
return await async_api.zeros(shape=shape, store=self.store_path, path=name, **kwargs)
712713

713-
async def ones(self, **kwargs: Any) -> AsyncArray:
714-
raise NotImplementedError
714+
async def ones(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> AsyncArray:
715+
return await async_api.ones(shape=shape, store=self.store_path, path=name, **kwargs)
715716

716-
async def full(self, **kwargs: Any) -> AsyncArray:
717-
raise NotImplementedError
717+
async def full(
718+
self, *, name: str, shape: ChunkCoords, fill_value: Any | None, **kwargs: Any
719+
) -> AsyncArray:
720+
return await async_api.full(
721+
shape=shape, fill_value=fill_value, store=self.store_path, path=name, **kwargs
722+
)
718723

719-
async def empty_like(self, prototype: AsyncArray, **kwargs: Any) -> AsyncArray:
720-
raise NotImplementedError
724+
async def empty_like(
725+
self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any
726+
) -> AsyncArray:
727+
return await async_api.empty_like(a=prototype, store=self.store_path, path=name, **kwargs)
721728

722-
async def zeros_like(self, prototype: AsyncArray, **kwargs: Any) -> AsyncArray:
723-
raise NotImplementedError
729+
async def zeros_like(
730+
self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any
731+
) -> AsyncArray:
732+
return await async_api.zeros_like(a=prototype, store=self.store_path, path=name, **kwargs)
724733

725-
async def ones_like(self, prototype: AsyncArray, **kwargs: Any) -> AsyncArray:
726-
raise NotImplementedError
734+
async def ones_like(
735+
self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any
736+
) -> AsyncArray:
737+
return await async_api.ones_like(a=prototype, store=self.store_path, path=name, **kwargs)
727738

728-
async def full_like(self, prototype: AsyncArray, **kwargs: Any) -> AsyncArray:
729-
raise NotImplementedError
739+
async def full_like(
740+
self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any
741+
) -> AsyncArray:
742+
return await async_api.full_like(a=prototype, store=self.store_path, path=name, **kwargs)
730743

731744
async def move(self, source: str, dest: str) -> None:
732745
raise NotImplementedError
@@ -1058,29 +1071,43 @@ def require_array(self, name: str, **kwargs: Any) -> Array:
10581071
"""
10591072
return Array(self._sync(self._async_group.require_array(name, **kwargs)))
10601073

1061-
def empty(self, **kwargs: Any) -> Array:
1062-
return Array(self._sync(self._async_group.empty(**kwargs)))
1074+
def empty(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> Array:
1075+
return Array(self._sync(self._async_group.empty(name=name, shape=shape, **kwargs)))
10631076

1064-
def zeros(self, **kwargs: Any) -> Array:
1065-
return Array(self._sync(self._async_group.zeros(**kwargs)))
1077+
def zeros(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> Array:
1078+
return Array(self._sync(self._async_group.zeros(name=name, shape=shape, **kwargs)))
10661079

1067-
def ones(self, **kwargs: Any) -> Array:
1068-
return Array(self._sync(self._async_group.ones(**kwargs)))
1080+
def ones(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> Array:
1081+
return Array(self._sync(self._async_group.ones(name=name, shape=shape, **kwargs)))
10691082

1070-
def full(self, **kwargs: Any) -> Array:
1071-
return Array(self._sync(self._async_group.full(**kwargs)))
1083+
def full(
1084+
self, *, name: str, shape: ChunkCoords, fill_value: Any | None, **kwargs: Any
1085+
) -> Array:
1086+
return Array(
1087+
self._sync(
1088+
self._async_group.full(name=name, shape=shape, fill_value=fill_value, **kwargs)
1089+
)
1090+
)
10721091

1073-
def empty_like(self, prototype: AsyncArray, **kwargs: Any) -> Array:
1074-
return Array(self._sync(self._async_group.empty_like(prototype, **kwargs)))
1092+
def empty_like(self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any) -> Array:
1093+
return Array(
1094+
self._sync(self._async_group.empty_like(name=name, prototype=prototype, **kwargs))
1095+
)
10751096

1076-
def zeros_like(self, prototype: AsyncArray, **kwargs: Any) -> Array:
1077-
return Array(self._sync(self._async_group.zeros_like(prototype, **kwargs)))
1097+
def zeros_like(self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any) -> Array:
1098+
return Array(
1099+
self._sync(self._async_group.zeros_like(name=name, prototype=prototype, **kwargs))
1100+
)
10781101

1079-
def ones_like(self, prototype: AsyncArray, **kwargs: Any) -> Array:
1080-
return Array(self._sync(self._async_group.ones_like(prototype, **kwargs)))
1102+
def ones_like(self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any) -> Array:
1103+
return Array(
1104+
self._sync(self._async_group.ones_like(name=name, prototype=prototype, **kwargs))
1105+
)
10811106

1082-
def full_like(self, prototype: AsyncArray, **kwargs: Any) -> Array:
1083-
return Array(self._sync(self._async_group.full_like(prototype, **kwargs)))
1107+
def full_like(self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any) -> Array:
1108+
return Array(
1109+
self._sync(self._async_group.full_like(name=name, prototype=prototype, **kwargs))
1110+
)
10841111

10851112
def move(self, source: str, dest: str) -> None:
10861113
return self._sync(self._async_group.move(source, dest))

tests/v3/test_group.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,73 @@ def test_group_create_array(
390390
assert np.array_equal(array[:], data)
391391

392392

393+
def test_group_array_creation(
394+
store: Store,
395+
zarr_format: ZarrFormat,
396+
):
397+
group = Group.create(store, zarr_format=zarr_format)
398+
shape = (10, 10)
399+
empty_array = group.empty(name="empty", shape=shape)
400+
assert isinstance(empty_array, Array)
401+
assert empty_array.fill_value == 0
402+
assert empty_array.shape == shape
403+
assert empty_array.store_path.store == store
404+
405+
empty_like_array = group.empty_like(name="empty_like", prototype=empty_array)
406+
assert isinstance(empty_like_array, Array)
407+
assert empty_like_array.fill_value == 0
408+
assert empty_like_array.shape == shape
409+
assert empty_like_array.store_path.store == store
410+
411+
empty_array_bool = group.empty(name="empty_bool", shape=shape, dtype=np.dtype("bool"))
412+
assert isinstance(empty_array_bool, Array)
413+
assert not empty_array_bool.fill_value
414+
assert empty_array_bool.shape == shape
415+
assert empty_array_bool.store_path.store == store
416+
417+
empty_like_array_bool = group.empty_like(name="empty_like_bool", prototype=empty_array_bool)
418+
assert isinstance(empty_like_array_bool, Array)
419+
assert not empty_like_array_bool.fill_value
420+
assert empty_like_array_bool.shape == shape
421+
assert empty_like_array_bool.store_path.store == store
422+
423+
zeros_array = group.zeros(name="zeros", shape=shape)
424+
assert isinstance(zeros_array, Array)
425+
assert zeros_array.fill_value == 0
426+
assert zeros_array.shape == shape
427+
assert zeros_array.store_path.store == store
428+
429+
zeros_like_array = group.zeros_like(name="zeros_like", prototype=zeros_array)
430+
assert isinstance(zeros_like_array, Array)
431+
assert zeros_like_array.fill_value == 0
432+
assert zeros_like_array.shape == shape
433+
assert zeros_like_array.store_path.store == store
434+
435+
ones_array = group.ones(name="ones", shape=shape)
436+
assert isinstance(ones_array, Array)
437+
assert ones_array.fill_value == 1
438+
assert ones_array.shape == shape
439+
assert ones_array.store_path.store == store
440+
441+
ones_like_array = group.ones_like(name="ones_like", prototype=ones_array)
442+
assert isinstance(ones_like_array, Array)
443+
assert ones_like_array.fill_value == 1
444+
assert ones_like_array.shape == shape
445+
assert ones_like_array.store_path.store == store
446+
447+
full_array = group.full(name="full", shape=shape, fill_value=42)
448+
assert isinstance(full_array, Array)
449+
assert full_array.fill_value == 42
450+
assert full_array.shape == shape
451+
assert full_array.store_path.store == store
452+
453+
full_like_array = group.full_like(name="full_like", prototype=full_array, fill_value=43)
454+
assert isinstance(full_like_array, Array)
455+
assert full_like_array.fill_value == 43
456+
assert full_like_array.shape == shape
457+
assert full_like_array.store_path.store == store
458+
459+
393460
@pytest.mark.parametrize("store", ("local", "memory", "zip"), indirect=["store"])
394461
@pytest.mark.parametrize("zarr_format", (2, 3))
395462
@pytest.mark.parametrize("exists_ok", [True, False])

0 commit comments

Comments
 (0)