@@ -1126,6 +1126,13 @@ def rechunk_new(x, chunks, *, min_mem=None):
1126
1126
cubed.Array
1127
1127
An array with the desired chunks.
1128
1128
"""
1129
+ out = x
1130
+ for copy_chunks , target_chunks in _rechunk_plan (x , chunks , min_mem = min_mem ):
1131
+ out = _rechunk (out , copy_chunks , target_chunks )
1132
+ return out
1133
+
1134
+
1135
+ def _rechunk_plan (x , chunks , * , min_mem = None ):
1129
1136
if isinstance (chunks , dict ):
1130
1137
chunks = {validate_axis (c , x .ndim ): v for c , v in chunks .items ()}
1131
1138
for i in range (x .ndim ):
@@ -1165,7 +1172,6 @@ def rechunk_new(x, chunks, *, min_mem=None):
1165
1172
max_mem = rechunker_max_mem ,
1166
1173
)
1167
1174
1168
- out = x
1169
1175
for i , stage in enumerate (stages ):
1170
1176
last_stage = i == len (stages ) - 1
1171
1177
read_chunks , int_chunks , write_chunks = stage
@@ -1174,12 +1180,10 @@ def rechunk_new(x, chunks, *, min_mem=None):
1174
1180
target_chunks_ = target_chunks if last_stage else write_chunks
1175
1181
1176
1182
if read_chunks == write_chunks :
1177
- out = _rechunk ( out , read_chunks , target_chunks_ )
1183
+ yield read_chunks , target_chunks_
1178
1184
else :
1179
- intermediate = _rechunk (out , read_chunks , int_chunks )
1180
- out = _rechunk (intermediate , write_chunks , target_chunks_ )
1181
-
1182
- return out
1185
+ yield read_chunks , int_chunks
1186
+ yield write_chunks , target_chunks_
1183
1187
1184
1188
1185
1189
def _rechunk (x , copy_chunks , target_chunks ):
@@ -1217,62 +1221,6 @@ def selection_function(out_key):
1217
1221
)
1218
1222
1219
1223
1220
- def rechunk_plan (x , chunks , * , min_mem = None ):
1221
- if isinstance (chunks , dict ):
1222
- chunks = {validate_axis (c , x .ndim ): v for c , v in chunks .items ()}
1223
- for i in range (x .ndim ):
1224
- if i not in chunks :
1225
- chunks [i ] = x .chunks [i ]
1226
- elif chunks [i ] is None :
1227
- chunks [i ] = x .chunks [i ]
1228
- if isinstance (chunks , (tuple , list )):
1229
- chunks = tuple (lc if lc is not None else rc for lc , rc in zip (chunks , x .chunks ))
1230
-
1231
- normalized_chunks = normalize_chunks (chunks , x .shape , dtype = x .dtype )
1232
- if x .chunks == normalized_chunks :
1233
- return x
1234
- # normalizing takes care of dict args for chunks
1235
- target_chunks = to_chunksize (normalized_chunks )
1236
-
1237
- # merge chunks special case
1238
- if all (c1 % c0 == 0 for c0 , c1 in zip (x .chunksize , target_chunks )):
1239
- return merge_chunks (x , target_chunks )
1240
-
1241
- spec = x .spec
1242
- source_chunks = to_chunksize (normalize_chunks (x .chunks , x .shape , dtype = x .dtype ))
1243
-
1244
- # rechunker doesn't take account of uncompressed and compressed copies of the
1245
- # input and output array chunk/selection, so adjust appropriately
1246
- rechunker_max_mem = (spec .allowed_mem - spec .reserved_mem ) // 5
1247
- if min_mem is None :
1248
- min_mem = min (rechunker_max_mem // 20 , x .nbytes )
1249
- stages = multistage_rechunking_plan (
1250
- shape = x .shape ,
1251
- source_chunks = source_chunks ,
1252
- target_chunks = target_chunks ,
1253
- itemsize = x .dtype .itemsize ,
1254
- min_mem = min_mem ,
1255
- max_mem = rechunker_max_mem ,
1256
- )
1257
-
1258
- source_chunks = x .chunksize
1259
- for i , stage in enumerate (stages ):
1260
- last_stage = i == len (stages ) - 1
1261
- read_chunks , int_chunks , write_chunks = stage
1262
-
1263
- # Use target chunks for last stage
1264
- target_chunks_ = target_chunks if last_stage else write_chunks
1265
-
1266
- if read_chunks == write_chunks :
1267
- yield source_chunks , read_chunks , target_chunks_
1268
- source_chunks = target_chunks_
1269
- else :
1270
- yield source_chunks , read_chunks , int_chunks
1271
- source_chunks = int_chunks
1272
- yield source_chunks , write_chunks , target_chunks_
1273
- source_chunks = target_chunks_
1274
-
1275
-
1276
1224
def merge_chunks (x , chunks ):
1277
1225
"""Merge multiple chunks into one."""
1278
1226
target_chunksize = chunks
0 commit comments