@@ -60,6 +60,7 @@ def _reproject_dispatcher(
    shape_out,
    wcs_out,
    block_size=None,
+    non_reprojected_dims=None,
    array_out=None,
    return_footprint=True,
    output_footprint=None,
@@ -92,6 +93,11 @@ def _reproject_dispatcher(
        the block size automatically determined. If ``block_size`` is not
        specified or set to `None`, the reprojection will not be carried out in
        blocks.
+    non_reprojected_dims : tuple, optional
+        Dimensions that should not be reprojected but instead for which a
+        1-to-1 mapping between input and output pixel space should be assumed.
+        By default, this is any leading extra dimensions if the input WCS has
+        fewer dimensions than the input data.
    array_out : `~numpy.ndarray`, optional
        An array in which to store the reprojected data. This can be any numpy
        array including a memory map, which may be helpful when dealing with
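For orientation, here is a minimal sketch (not part of this diff; the shape and `pixel_n_dim` value are invented) of how the documented default plays out for a spectral cube reprojected with a purely celestial 2-D WCS:

# Hypothetical illustration of the documented default: with a 3-D output shape
# and a 2-D (celestial-only) input WCS, the single leading (spectral) axis is
# treated as a non-reprojected, 1-to-1 dimension.
shape_out = (4, 256, 256)        # assumed output cube shape
pixel_n_dim = 2                  # stand-in for wcs_in.low_level_wcs.pixel_n_dim
non_reprojected_dims = list(range(len(shape_out) - pixel_n_dim))
print(non_reprojected_dims)      # [0]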
@@ -198,9 +204,32 @@ def _reproject_dispatcher(
    # shape_out will be the full size of the output array as this is updated
    # in parse_output_projection, even if shape_out was originally passed in as
    # the shape of a single image.
-    broadcasting = wcs_in.low_level_wcs.pixel_n_dim < len(shape_out)
+    if non_reprojected_dims is None:
+        non_reprojected_dims = list(range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim))
+    else:
+        non_reprojected_dims = list(non_reprojected_dims)
+
+    broadcasting = len(non_reprojected_dims) > 0
+
+    reprojected_dims = [x for x in range(len(shape_out)) if x not in non_reprojected_dims]

    logger.info(f"Broadcasting is {'' if broadcasting else 'not '}being used")
+    logger.info(f"Dimensions being reprojected: {reprojected_dims}")
+    logger.info(f"Dimensions not being reprojected: {non_reprojected_dims}")
+
+    if len(block_size) < len(shape_out):
+        block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
+    elif len(block_size) > len(shape_out):
+        raise ValueError(
+            f"block_size {len(block_size)} cannot have more elements "
+            f"than the dimensionality of the output ({len(shape_out)})"
+        )
+
+    block_size = np.array(block_size)
+    shape_out = np.array(shape_out)
+
+    # TODO: replace block size of -1 by actual value for logic below to work
+    # TODO: re-implement block_size auto

    # Check block size and determine whether block size indicates we should
    # parallelize over broadcasted dimension. The logic is as follows: if
@@ -212,33 +241,23 @@ def _reproject_dispatcher(
    # don't make any assumptions for now and assume a single chunk in the
    # missing dimensions.
    broadcasted_parallelization = False
-    if broadcasting and block_size is not None and block_size != "auto":
-        if len(block_size) == len(shape_out):
-            if (
-                block_size[-wcs_in.low_level_wcs.pixel_n_dim :]
-                == shape_out[-wcs_in.low_level_wcs.pixel_n_dim :]
-            ):
-                broadcasted_parallelization = True
-                block_size = (
-                    block_size[: -wcs_in.low_level_wcs.pixel_n_dim]
-                    + (-1,) * wcs_in.low_level_wcs.pixel_n_dim
-                )
-            else:
-                for i in range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim):
-                    if block_size[i] != -1 and block_size[i] != shape_out[i]:
-                        raise ValueError(
-                            "block shape should either match output data shape along broadcasted dimension or non-broadcasted dimensions"
-                        )
-        elif len(block_size) < len(shape_out):
-            block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
-        else:
-            raise ValueError(
-                f"block_size {len(block_size)} cannot have more elements "
-                f"than the dimensionality of the output ({len(shape_out)})"
+    if broadcasting and block_size is not None:
+        if np.all(block_size[reprojected_dims] == shape_out[reprojected_dims]):
+            broadcasted_parallelization = True
+            block_size = np.array(
+                tuple(block_size[non_reprojected_dims].tolist())
+                + (-1,) * len(reprojected_dims)
            )
+        elif np.all(block_size[non_reprojected_dims] != shape_out[non_reprojected_dims]):
+            raise ValueError(
+                "block shape should either match output data shape along broadcasted dimension or non-broadcasted dimensions"
+            )

    # TODO: check for shape_out not matching shape_in along broadcasted dimensions

+    block_size = tuple(block_size.tolist())
+    shape_out = tuple(shape_out.tolist())
+
    logger.info(
        f"{'P' if broadcasted_parallelization else 'Not p'}arallelizing along "
        f"broadcasted dimension ({block_size=}, {shape_out=})"
@@ -270,17 +289,38 @@ def reproject_single_block(a, array_or_path, block_info=None):
        wcs_in_cp = wcs_in.deepcopy() if isinstance(wcs_in, WCS) else wcs_in
        wcs_out_cp = wcs_out.deepcopy() if isinstance(wcs_out, WCS) else wcs_out

-        slices = [
-            slice(*x) for x in block_info[None]["array-location"][-wcs_out_cp.pixel_n_dim :]
-        ]
+        print(block_info[None]["array-location"])

-        if isinstance(wcs_out, BaseHighLevelWCS):
+        slices = []
+        for i in reprojected_dims:
+            slices.append(slice(*block_info[None]["array-location"][i]))
+
+        print(slices)
+
+        if isinstance(wcs_out_cp, BaseHighLevelWCS):
            low_level_wcs = SlicedLowLevelWCS(wcs_out_cp.low_level_wcs, slices=slices)
        else:
            low_level_wcs = SlicedLowLevelWCS(wcs_out_cp, slices=slices)

+        print(low_level_wcs.pixel_n_dim, low_level_wcs.world_n_dim)
+
        wcs_out_sub = HighLevelWCSWrapper(low_level_wcs)

+        slices = []
+        for i in range(wcs_in_cp.pixel_n_dim):
+            if i in non_reprojected_dims:
+                # slices.append(slice(*block_info[None]["array-location"][i]))
+                slices.append(block_info[None]["array-location"][i][0])
+            else:
+                slices.append(slice(None))
+
+        if isinstance(wcs_in_cp, BaseHighLevelWCS):
+            low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp.low_level_wcs, slices=slices)
+        else:
+            low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp, slices=slices)
+
+        wcs_in_sub = HighLevelWCSWrapper(low_level_wcs_in)
+
        if isinstance(array_or_path, tuple):
            array_in = np.memmap(array_or_path[0], **array_or_path[1], mode="r")
        elif isinstance(array_or_path, str):
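To make the slicing pattern above easier to follow, here is a standalone sketch (not part of the diff) using astropy's WCS wrappers; the 2-D WCS and the slice bounds are made up:

from astropy.wcs import WCS
from astropy.wcs.wcsapi import HighLevelWCSWrapper, SlicedLowLevelWCS

wcs_out = WCS(naxis=2)                      # stand-in for wcs_out_cp
slices = [slice(0, 128), slice(None)]       # one block's "array-location" as slices

# Same construction as in reproject_single_block: slice the low-level WCS down
# to this block, then wrap it back up as a high-level WCS for the reprojection call.
wcs_block = HighLevelWCSWrapper(SlicedLowLevelWCS(wcs_out.low_level_wcs, slices=slices))
print(wcs_block.low_level_wcs.pixel_n_dim)  # 2: slicing with slice objects keeps both axes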
@@ -295,7 +335,7 @@ def reproject_single_block(a, array_or_path, block_info=None):

        array, footprint = reproject_func(
            array_in,
-            wcs_in_cp,
+            wcs_in_sub,
            wcs_out_sub,
            shape_out=shape_out,
            array_out=np.zeros(shape_out),
@@ -308,10 +348,11 @@ def reproject_single_block(a, array_or_path, block_info=None):

        array_out_dask = da.empty(shape_out, chunks=block_size)
        if isinstance(array_in, da.core.Array):
+            # FIXME: Should take into account -1s here
            if array_in.chunksize != block_size:
                logger.info(
                    f"Rechunking input dask array as chunks ({array_in.chunksize}) "
-                    "do not match block size ({block_size})"
+                    f"do not match block size ({block_size})"
                )
                array_in = array_in.rechunk(block_size)
        else:
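As a sanity check on the FIXME above, a small standalone snippet (not part of the diff; shapes invented) showing why a literal comparison against a block size containing -1 never matches dask's reported chunk sizes:

import dask.array as da

block_size = (-1, 256, 256)                      # -1 means "full axis" to dask
array_in = da.zeros((4, 1024, 1024), chunks=block_size)

print(array_in.chunksize)                        # (4, 256, 256): the -1 is expanded
print(array_in.chunksize != block_size)          # True, so the rechunk branch runs
                                                 # even though the chunks already match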