Skip to content

Commit 450bcc6

Browse files
committed
merge
1 parent 76c3450 commit 450bcc6

File tree

5 files changed

+480
-165
lines changed

5 files changed

+480
-165
lines changed

src/zarr/v3/abc/codec.py

Lines changed: 132 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from __future__ import annotations
22

33
from abc import abstractmethod
4-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar
55

66
import numpy as np
77
from zarr.v3.abc.metadata import Metadata
88

9-
from zarr.v3.common import ArraySpec
9+
from zarr.v3.common import ArraySpec, concurrent_map
1010
from zarr.v3.store import StorePath
1111

1212

@@ -16,6 +16,22 @@
1616
from zarr.v3.metadata import ArrayMetadata
1717
from zarr.v3.config import RuntimeConfiguration
1818

19+
T = TypeVar("T")
20+
U = TypeVar("U")
21+
22+
23+
def noop_for_none(
    func: Callable[[T, ArraySpec, RuntimeConfiguration], Awaitable[Optional[U]]]
) -> Callable[[Optional[T], ArraySpec, RuntimeConfiguration], Awaitable[Optional[U]]]:
    """Wrap an async chunk function so that a ``None`` chunk is passed through.

    The returned coroutine function has the same three-argument call shape as
    ``func``. When its first argument is ``None`` it returns ``None`` without
    calling ``func``; otherwise it awaits ``func`` with the arguments unchanged
    and returns that result.

    Args:
        func: Async callable invoked as
            ``func(chunk, chunk_spec, runtime_configuration)`` for every
            chunk that is not ``None``.

    Returns:
        An async callable that additionally accepts ``None`` as the chunk and
        may therefore return ``None``.
    """

    async def wrap(
        chunk: Optional[T], chunk_spec: ArraySpec, runtime_configuration: RuntimeConfiguration
    ) -> Optional[U]:
        if chunk is None:
            # Nothing to transform: skip the call to func entirely.
            return None
        return await func(chunk, chunk_spec, runtime_configuration)

    return wrap
34+
1935

2036
class Codec(Metadata):
2137
is_fixed_size: bool
@@ -44,6 +60,20 @@ async def decode(
4460
) -> np.ndarray:
4561
pass
4662

63+
async def decode_batch(
    self,
    chunk_arrays_and_specs: Iterable[Tuple[Optional[np.ndarray], ArraySpec]],
    runtime_configuration: RuntimeConfiguration,
) -> Iterable[Optional[np.ndarray]]:
    """Decode a batch of chunks by applying ``self.decode`` to each one.

    ``self.decode`` is wrapped with ``noop_for_none``, so an entry whose chunk
    is ``None`` yields ``None`` instead of being decoded; the annotations
    reflect that.

    Args:
        chunk_arrays_and_specs: Pairs of (chunk or ``None``, chunk spec).
        runtime_configuration: Forwarded to every ``decode`` call; its
            ``concurrency`` attribute is handed to ``concurrent_map``
            (presumably as the concurrency limit -- confirm against
            ``zarr.v3.common.concurrent_map``).

    Returns:
        Whatever ``concurrent_map`` returns for the per-chunk calls.
    """
    return await concurrent_map(
        [
            (chunk_array, chunk_spec, runtime_configuration)
            for chunk_array, chunk_spec in chunk_arrays_and_specs
        ],
        noop_for_none(self.decode),
        runtime_configuration.concurrency,
    )
76+
4777
@abstractmethod
4878
async def encode(
4979
self,
@@ -53,17 +83,45 @@ async def encode(
5383
) -> Optional[np.ndarray]:
5484
pass
5585

86+
async def encode_batch(
    self,
    chunk_arrays_and_specs: Iterable[Tuple[Optional[np.ndarray], ArraySpec]],
    runtime_configuration: RuntimeConfiguration,
) -> Iterable[Optional[np.ndarray]]:
    """Encode every chunk of a batch through ``self.encode``.

    Args:
        chunk_arrays_and_specs: Pairs of (chunk or ``None``, chunk spec).
        runtime_configuration: Forwarded to every ``encode`` call; its
            ``concurrency`` attribute is passed on to ``concurrent_map``.

    Returns:
        The result of ``concurrent_map`` over the per-chunk calls.
    """
    # One argument tuple per chunk, in the parameter order of encode.
    call_args = [
        (array, spec, runtime_configuration) for array, spec in chunk_arrays_and_specs
    ]
    # A None chunk comes back as None instead of being handed to encode.
    encode_or_skip = noop_for_none(self.encode)
    return await concurrent_map(call_args, encode_or_skip, runtime_configuration.concurrency)
99+
56100

57101
class ArrayBytesCodec(Codec):
58102
@abstractmethod
59103
async def decode(
60104
self,
61-
chunk_array: BytesLike,
105+
chunk_bytes: BytesLike,
62106
chunk_spec: ArraySpec,
63107
runtime_configuration: RuntimeConfiguration,
64108
) -> np.ndarray:
65109
pass
66110

111+
async def decode_batch(
    self,
    chunk_bytes_and_specs: Iterable[Tuple[Optional[BytesLike], ArraySpec]],
    runtime_configuration: RuntimeConfiguration,
) -> Iterable[Optional[np.ndarray]]:
    """Decode a batch of byte chunks into arrays via ``self.decode``.

    ``self.decode`` is wrapped with ``noop_for_none``, so an entry whose bytes
    are ``None`` yields ``None`` instead of being decoded; the annotations
    reflect that.

    Args:
        chunk_bytes_and_specs: Pairs of (chunk bytes or ``None``, chunk spec).
        runtime_configuration: Forwarded to every ``decode`` call; its
            ``concurrency`` attribute is handed to ``concurrent_map``
            (presumably as the concurrency limit -- confirm against
            ``zarr.v3.common.concurrent_map``).

    Returns:
        Whatever ``concurrent_map`` returns for the per-chunk calls.
    """
    return await concurrent_map(
        [
            (chunk_bytes, chunk_spec, runtime_configuration)
            for chunk_bytes, chunk_spec in chunk_bytes_and_specs
        ],
        noop_for_none(self.decode),
        runtime_configuration.concurrency,
    )
124+
67125
@abstractmethod
68126
async def encode(
69127
self,
@@ -73,6 +131,20 @@ async def encode(
73131
) -> Optional[BytesLike]:
74132
pass
75133

134+
async def encode_batch(
    self,
    chunk_arrays_and_specs: Iterable[Tuple[Optional[np.ndarray], ArraySpec]],
    runtime_configuration: RuntimeConfiguration,
) -> Iterable[Optional[BytesLike]]:
    """Encode every array chunk of a batch into bytes through ``self.encode``.

    Args:
        chunk_arrays_and_specs: Pairs of (chunk array or ``None``, chunk spec).
        runtime_configuration: Forwarded to every ``encode`` call; its
            ``concurrency`` attribute is passed on to ``concurrent_map``.

    Returns:
        The result of ``concurrent_map`` over the per-chunk calls.
    """
    # One argument tuple per chunk, in the parameter order of encode.
    call_args = [
        (array, spec, runtime_configuration) for array, spec in chunk_arrays_and_specs
    ]
    # A None chunk comes back as None instead of being handed to encode.
    encode_or_skip = noop_for_none(self.encode)
    return await concurrent_map(call_args, encode_or_skip, runtime_configuration.concurrency)
147+
76148

77149
class ArrayBytesCodecPartialDecodeMixin:
78150
@abstractmethod
@@ -85,6 +157,20 @@ async def decode_partial(
85157
) -> Optional[np.ndarray]:
86158
pass
87159

160+
async def decode_partial_batched(
    self,
    batch_info: Iterable[Tuple[StorePath, SliceSelection, ArraySpec]],
    runtime_configuration: RuntimeConfiguration,
) -> Iterable[Optional[np.ndarray]]:
    """Run ``self.decode_partial`` for every entry of a batch.

    Args:
        batch_info: Triples of (store path, selection, chunk spec).
        runtime_configuration: Appended to each triple as the last argument
            of ``decode_partial``; its ``concurrency`` attribute is passed on
            to ``concurrent_map``.

    Returns:
        The result of ``concurrent_map`` over the per-entry calls.
    """
    # Extend each triple with the shared runtime configuration.
    call_args = [
        (path, selection, spec, runtime_configuration) for path, selection, spec in batch_info
    ]
    return await concurrent_map(
        call_args, self.decode_partial, runtime_configuration.concurrency
    )
173+
88174

89175
class ArrayBytesCodecPartialEncodeMixin:
90176
@abstractmethod
@@ -98,17 +184,45 @@ async def encode_partial(
98184
) -> None:
99185
pass
100186

187+
async def encode_partial_batched(
    self,
    batch_info: Iterable[Tuple[StorePath, np.ndarray, SliceSelection, ArraySpec]],
    runtime_configuration: RuntimeConfiguration,
) -> None:
    """Run ``self.encode_partial`` for every entry of a batch.

    Args:
        batch_info: 4-tuples of (store path, chunk array, selection, chunk spec).
        runtime_configuration: Appended to each tuple as the last argument of
            ``encode_partial``; its ``concurrency`` attribute is passed on to
            ``concurrent_map``.

    Returns:
        None. Whatever ``concurrent_map`` returns is discarded.
    """
    # Extend each 4-tuple with the shared runtime configuration.
    call_args = [
        (path, array, selection, spec, runtime_configuration)
        for path, array, selection, spec in batch_info
    ]
    await concurrent_map(call_args, self.encode_partial, runtime_configuration.concurrency)
200+
101201

102202
class BytesBytesCodec(Codec):
103203
@abstractmethod
104204
async def decode(
105205
self,
106-
chunk_array: BytesLike,
206+
chunk_bytes: BytesLike,
107207
chunk_spec: ArraySpec,
108208
runtime_configuration: RuntimeConfiguration,
109209
) -> BytesLike:
110210
pass
111211

212+
async def decode_batch(
    self,
    chunk_bytes_and_specs: Iterable[Tuple[Optional[BytesLike], ArraySpec]],
    runtime_configuration: RuntimeConfiguration,
) -> Iterable[Optional[BytesLike]]:
    """Decode a batch of byte chunks by applying ``self.decode`` to each one.

    ``self.decode`` is wrapped with ``noop_for_none``, so an entry whose bytes
    are ``None`` yields ``None`` instead of being decoded; the annotations
    reflect that.

    Args:
        chunk_bytes_and_specs: Pairs of (chunk bytes or ``None``, chunk spec).
        runtime_configuration: Forwarded to every ``decode`` call; its
            ``concurrency`` attribute is handed to ``concurrent_map``
            (presumably as the concurrency limit -- confirm against
            ``zarr.v3.common.concurrent_map``).

    Returns:
        Whatever ``concurrent_map`` returns for the per-chunk calls.
    """
    return await concurrent_map(
        [
            (chunk_bytes, chunk_spec, runtime_configuration)
            for chunk_bytes, chunk_spec in chunk_bytes_and_specs
        ],
        noop_for_none(self.decode),
        runtime_configuration.concurrency,
    )
225+
112226
@abstractmethod
113227
async def encode(
114228
self,
@@ -117,3 +231,17 @@ async def encode(
117231
runtime_configuration: RuntimeConfiguration,
118232
) -> Optional[BytesLike]:
119233
pass
234+
235+
async def encode_batch(
    self,
    chunk_bytes_and_specs: Iterable[Tuple[Optional[BytesLike], ArraySpec]],
    runtime_configuration: RuntimeConfiguration,
) -> Iterable[Optional[BytesLike]]:
    """Encode every byte chunk of a batch through ``self.encode``.

    Args:
        chunk_bytes_and_specs: Pairs of (chunk bytes or ``None``, chunk spec).
        runtime_configuration: Forwarded to every ``encode`` call; its
            ``concurrency`` attribute is passed on to ``concurrent_map``.

    Returns:
        The result of ``concurrent_map`` over the per-chunk calls.
    """
    # One argument tuple per chunk, in the parameter order of encode.
    call_args = [
        (data, spec, runtime_configuration) for data, spec in chunk_bytes_and_specs
    ]
    # A None chunk comes back as None instead of being handed to encode.
    encode_or_skip = noop_for_none(self.encode)
    return await concurrent_map(call_args, encode_or_skip, runtime_configuration.concurrency)

0 commit comments

Comments
 (0)