Skip to content

Commit e9c81df

Browse files
authored
Remove duplicate code from new rechunk implementation (#702)
- Remove duplicate code from new rechunk implementation
- Test rechunking plan chunk sizes for ERA5
1 parent a8d4748 commit e9c81df

File tree

3 files changed

+35
-64
lines changed

3 files changed

+35
-64
lines changed

cubed/core/ops.py

Lines changed: 10 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,13 @@ def rechunk_new(x, chunks, *, min_mem=None):
11261126
cubed.Array
11271127
An array with the desired chunks.
11281128
"""
1129+
out = x
1130+
for copy_chunks, target_chunks in _rechunk_plan(x, chunks, min_mem=min_mem):
1131+
out = _rechunk(out, copy_chunks, target_chunks)
1132+
return out
1133+
1134+
1135+
def _rechunk_plan(x, chunks, *, min_mem=None):
11291136
if isinstance(chunks, dict):
11301137
chunks = {validate_axis(c, x.ndim): v for c, v in chunks.items()}
11311138
for i in range(x.ndim):
@@ -1165,7 +1172,6 @@ def rechunk_new(x, chunks, *, min_mem=None):
11651172
max_mem=rechunker_max_mem,
11661173
)
11671174

1168-
out = x
11691175
for i, stage in enumerate(stages):
11701176
last_stage = i == len(stages) - 1
11711177
read_chunks, int_chunks, write_chunks = stage
@@ -1174,12 +1180,10 @@ def rechunk_new(x, chunks, *, min_mem=None):
11741180
target_chunks_ = target_chunks if last_stage else write_chunks
11751181

11761182
if read_chunks == write_chunks:
1177-
out = _rechunk(out, read_chunks, target_chunks_)
1183+
yield read_chunks, target_chunks_
11781184
else:
1179-
intermediate = _rechunk(out, read_chunks, int_chunks)
1180-
out = _rechunk(intermediate, write_chunks, target_chunks_)
1181-
1182-
return out
1185+
yield read_chunks, int_chunks
1186+
yield write_chunks, target_chunks_
11831187

11841188

11851189
def _rechunk(x, copy_chunks, target_chunks):
@@ -1217,62 +1221,6 @@ def selection_function(out_key):
12171221
)
12181222

12191223

1220-
def rechunk_plan(x, chunks, *, min_mem=None):
1221-
if isinstance(chunks, dict):
1222-
chunks = {validate_axis(c, x.ndim): v for c, v in chunks.items()}
1223-
for i in range(x.ndim):
1224-
if i not in chunks:
1225-
chunks[i] = x.chunks[i]
1226-
elif chunks[i] is None:
1227-
chunks[i] = x.chunks[i]
1228-
if isinstance(chunks, (tuple, list)):
1229-
chunks = tuple(lc if lc is not None else rc for lc, rc in zip(chunks, x.chunks))
1230-
1231-
normalized_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
1232-
if x.chunks == normalized_chunks:
1233-
return x
1234-
# normalizing takes care of dict args for chunks
1235-
target_chunks = to_chunksize(normalized_chunks)
1236-
1237-
# merge chunks special case
1238-
if all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunks)):
1239-
return merge_chunks(x, target_chunks)
1240-
1241-
spec = x.spec
1242-
source_chunks = to_chunksize(normalize_chunks(x.chunks, x.shape, dtype=x.dtype))
1243-
1244-
# rechunker doesn't take account of uncompressed and compressed copies of the
1245-
# input and output array chunk/selection, so adjust appropriately
1246-
rechunker_max_mem = (spec.allowed_mem - spec.reserved_mem) // 5
1247-
if min_mem is None:
1248-
min_mem = min(rechunker_max_mem // 20, x.nbytes)
1249-
stages = multistage_rechunking_plan(
1250-
shape=x.shape,
1251-
source_chunks=source_chunks,
1252-
target_chunks=target_chunks,
1253-
itemsize=x.dtype.itemsize,
1254-
min_mem=min_mem,
1255-
max_mem=rechunker_max_mem,
1256-
)
1257-
1258-
source_chunks = x.chunksize
1259-
for i, stage in enumerate(stages):
1260-
last_stage = i == len(stages) - 1
1261-
read_chunks, int_chunks, write_chunks = stage
1262-
1263-
# Use target chunks for last stage
1264-
target_chunks_ = target_chunks if last_stage else write_chunks
1265-
1266-
if read_chunks == write_chunks:
1267-
yield source_chunks, read_chunks, target_chunks_
1268-
source_chunks = target_chunks_
1269-
else:
1270-
yield source_chunks, read_chunks, int_chunks
1271-
source_chunks = int_chunks
1272-
yield source_chunks, write_chunks, target_chunks_
1273-
source_chunks = target_chunks_
1274-
1275-
12761224
def merge_chunks(x, chunks):
12771225
"""Merge multiple chunks into one."""
12781226
target_chunksize = chunks

cubed/tests/test_mem_utilization_rechunk.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ def test_rechunk_era5(tmp_path, spec, executor):
6666

6767
x = cubed.random.random(shape, dtype=xp.float32, chunks=source_chunks, spec=spec)
6868

69-
from cubed.core.ops import rechunk_plan
69+
from cubed.core.ops import _rechunk_plan
7070

7171
i = 0
72-
for source_chunks, copy_chunks, target_chunks in rechunk_plan(x, target_chunks):
72+
for copy_chunks, target_chunks in _rechunk_plan(x, target_chunks):
7373
# Find the smallest shape that contains the three chunk sizes
7474
# This will be a lot less than the full ERA5 shape (350640, 721, 1440),
7575
# making it suitable for running in a test
@@ -85,4 +85,5 @@ def test_rechunk_era5(tmp_path, spec, executor):
8585

8686
run_operation(tmp_path, executor, f"rechunk_era5_stage_{i}", b)
8787

88+
source_chunks = target_chunks
8889
i += 1

cubed/tests/test_rechunk.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,25 @@ def test_rechunk_era5(
5252
d["pipeline"].config.num_output_blocks[0] for _, d in rechunks
5353
)
5454
assert max_output_blocks == expected_max_output_blocks
55+
56+
57+
def test_rechunk_era5_chunk_sizes():
58+
# from https://github.com/pangeo-data/rechunker/pull/89
59+
shape = (350640, 721, 1440)
60+
source_chunks = (31, 721, 1440)
61+
target_chunks = (350640, 10, 10)
62+
63+
spec = cubed.Spec(allowed_mem="2.5GB")
64+
65+
a = xp.empty(shape, dtype=xp.float32, chunks=source_chunks, spec=spec)
66+
67+
from cubed.core.ops import _rechunk_plan
68+
69+
assert list(_rechunk_plan(a, target_chunks)) == [
70+
((93, 721, 1440), (93, 173, 396)),
71+
((1447, 173, 396), (1447, 173, 396)),
72+
((1447, 173, 396), (1447, 41, 109)),
73+
((22528, 41, 109), (22528, 41, 109)),
74+
((22528, 41, 109), (22528, 10, 30)),
75+
((350640, 10, 30), (350640, 10, 10)),
76+
]

0 commit comments

Comments (0)