Commit ddcca20

WIP
1 parent 1f98c78 commit ddcca20

3 files changed: +74 -31 lines changed


reproject/common.py

Lines changed: 71 additions & 30 deletions
@@ -60,6 +60,7 @@ def _reproject_dispatcher(
     shape_out,
     wcs_out,
     block_size=None,
+    non_reprojected_dims=None,
     array_out=None,
     return_footprint=True,
     output_footprint=None,
@@ -92,6 +93,11 @@ def _reproject_dispatcher(
         the block size automatically determined. If ``block_size`` is not
         specified or set to `None`, the reprojection will not be carried out in
         blocks.
+    non_reprojected_dims : tuple, optional
+        Dimensions that should not be reprojected, and for which a 1-to-1
+        mapping between input and output pixel space should be assumed
+        instead. By default, these are any leading extra dimensions present
+        when the input WCS has fewer dimensions than the input data.
     array_out : `~numpy.ndarray`, optional
         An array in which to store the reprojected data. This can be any numpy
         array including a memory map, which may be helpful when dealing with
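
The default described above can be computed directly. A minimal sketch, assuming a 3-D data cube paired with a purely celestial 2-D input WCS (all names below are illustrative, not part of the diff):

from astropy.wcs import WCS

shape_out = (5, 64, 64)  # e.g. a spectral cube
wcs_in = WCS(naxis=2)    # WCS only covers the two trailing celestial axes

# Default from the diff: any leading extra dimensions are not reprojected
n_extra = len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim
non_reprojected_dims = list(range(n_extra))
print(non_reprojected_dims)  # [0] -> dimension 0 maps 1-to-1 to the output
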
@@ -198,9 +204,32 @@ def _reproject_dispatcher(
     # shape_out will be the full size of the output array as this is updated
     # in parse_output_projection, even if shape_out was originally passed in as
     # the shape of a single image.
-    broadcasting = wcs_in.low_level_wcs.pixel_n_dim < len(shape_out)
+    if non_reprojected_dims is None:
+        non_reprojected_dims = list(range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim))
+    else:
+        non_reprojected_dims = list(non_reprojected_dims)
+
+    broadcasting = len(non_reprojected_dims) > 0
+
+    reprojected_dims = [x for x in range(len(shape_out)) if x not in non_reprojected_dims]

     logger.info(f"Broadcasting is {'' if broadcasting else 'not '}being used")
+    logger.info(f"Dimensions being reprojected: {reprojected_dims}")
+    logger.info(f"Dimensions not being reprojected: {non_reprojected_dims}")
+
+    if len(block_size) < len(shape_out):
+        block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
+    elif len(block_size) > len(shape_out):
+        raise ValueError(
+            f"block_size {len(block_size)} cannot have more elements "
+            f"than the dimensionality of the output ({len(shape_out)})"
+        )
+
+    block_size = np.array(block_size)
+    shape_out = np.array(shape_out)
+
+    # TODO: replace block size of -1 by actual value for logic below to work
+    # TODO: re-implement block_size auto

     # Check block size and determine whether block size indicates we should
     # parallelize over broadcasted dimension. The logic is as follows: if
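
The normalization above left-pads a short block_size with -1 (one chunk spanning the whole axis) so it always has one entry per output dimension; note that the WIP code assumes block_size is already a sequence here (the None and "auto" code paths are still to be reinstated, per the TODOs). A standalone sketch of that padding, with hypothetical shapes:

shape_out = (5, 1024, 1024)
block_size = [256, 256]  # entries given only for the reprojected axes

if len(block_size) < len(shape_out):
    # left-pad with -1: a single chunk along each leading dimension
    block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)

print(block_size)  # [-1, 256, 256]
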
@@ -212,33 +241,23 @@ def _reproject_dispatcher(
     # don't make any assumptions for now and assume a single chunk in the
     # missing dimensions.
     broadcasted_parallelization = False
-    if broadcasting and block_size is not None and block_size != "auto":
-        if len(block_size) == len(shape_out):
-            if (
-                block_size[-wcs_in.low_level_wcs.pixel_n_dim :]
-                == shape_out[-wcs_in.low_level_wcs.pixel_n_dim :]
-            ):
-                broadcasted_parallelization = True
-                block_size = (
-                    block_size[: -wcs_in.low_level_wcs.pixel_n_dim]
-                    + (-1,) * wcs_in.low_level_wcs.pixel_n_dim
-                )
-            else:
-                for i in range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim):
-                    if block_size[i] != -1 and block_size[i] != shape_out[i]:
-                        raise ValueError(
-                            "block shape should either match output data shape along broadcasted dimension or non-broadcasted dimensions"
-                        )
-        elif len(block_size) < len(shape_out):
-            block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
-        else:
-            raise ValueError(
-                f"block_size {len(block_size)} cannot have more elements "
-                f"than the dimensionality of the output ({len(shape_out)})"
+    if broadcasting and block_size is not None:
+        if np.all(block_size[reprojected_dims] == shape_out[reprojected_dims]):
+            broadcasted_parallelization = True
+            block_size = np.array(
+                tuple(block_size[non_reprojected_dims].tolist())
+                + (-1,) * len(reprojected_dims)
             )
+        elif np.all(block_size[non_reprojected_dims] != shape_out[non_reprojected_dims]):
+            raise ValueError(
+                "block shape should either match output data shape along broadcasted dimension or non-broadcasted dimensions"
+            )

     # TODO: check for shape_out not matching shape_in along broadcasted dimensions

+    block_size = tuple(block_size.tolist())
+    shape_out = tuple(shape_out.tolist())
+
     logger.info(
         f"{'P' if broadcasted_parallelization else 'Not p'}arallelizing along "
         f"broadcasted dimension ({block_size=}, {shape_out=})"
@@ -270,17 +289,38 @@ def reproject_single_block(a, array_or_path, block_info=None):
         wcs_in_cp = wcs_in.deepcopy() if isinstance(wcs_in, WCS) else wcs_in
         wcs_out_cp = wcs_out.deepcopy() if isinstance(wcs_out, WCS) else wcs_out

-        slices = [
-            slice(*x) for x in block_info[None]["array-location"][-wcs_out_cp.pixel_n_dim :]
-        ]
+        print(block_info[None]["array-location"])

-        if isinstance(wcs_out, BaseHighLevelWCS):
+        slices = []
+        for i in reprojected_dims:
+            slices.append(slice(*block_info[None]["array-location"][i]))
+
+        print(slices)
+
+        if isinstance(wcs_out_cp, BaseHighLevelWCS):
             low_level_wcs = SlicedLowLevelWCS(wcs_out_cp.low_level_wcs, slices=slices)
         else:
             low_level_wcs = SlicedLowLevelWCS(wcs_out_cp, slices=slices)

+        print(low_level_wcs.pixel_n_dim, low_level_wcs.world_n_dim)
+
         wcs_out_sub = HighLevelWCSWrapper(low_level_wcs)

+        slices = []
+        for i in range(wcs_in_cp.pixel_n_dim):
+            if i in non_reprojected_dims:
+                # slices.append(slice(*block_info[None]["array-location"][i]))
+                slices.append(block_info[None]["array-location"][i][0])
+            else:
+                slices.append(slice(None))
+
+        if isinstance(wcs_in_cp, BaseHighLevelWCS):
+            low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp.low_level_wcs, slices=slices)
+        else:
+            low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp, slices=slices)
+
+        wcs_in_sub = HighLevelWCSWrapper(low_level_wcs_in)
+
         if isinstance(array_or_path, tuple):
             array_in = np.memmap(array_or_path[0], **array_or_path[1], mode="r")
         elif isinstance(array_or_path, str):
@@ -295,7 +335,7 @@ def reproject_single_block(a, array_or_path, block_info=None):

         array, footprint = reproject_func(
             array_in,
-            wcs_in_cp,
+            wcs_in_sub,
             wcs_out_sub,
             shape_out=shape_out,
             array_out=np.zeros(shape_out),
@@ -308,10 +348,11 @@ def reproject_single_block(a, array_or_path, block_info=None):

     array_out_dask = da.empty(shape_out, chunks=block_size)
     if isinstance(array_in, da.core.Array):
+        # FIXME: Should take into account -1s here
        if array_in.chunksize != block_size:
             logger.info(
                 f"Rechunking input dask array as chunks ({array_in.chunksize}) "
-                "do not match block size ({block_size})"
+                f"do not match block size ({block_size})"
             )
             array_in = array_in.rechunk(block_size)
     else:
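
The per-block WCS handling added to reproject_single_block can be exercised in isolation: the output WCS is sliced with the block's pixel ranges along the reprojected dimensions, while the input WCS receives an integer index along each non-reprojected dimension, dropping it to the dimensionality the core reprojection functions expect. A rough sketch under assumed inputs (this mirrors, but is not, the dispatcher's code path):

from astropy.wcs import WCS
from astropy.wcs.wcsapi import HighLevelWCSWrapper, SlicedLowLevelWCS

wcs_in = WCS(naxis=3)  # hypothetical 3-D input WCS
non_reprojected_dims = [0]
array_location = [(2, 3), (0, 256), (0, 256)]  # this block's extent

slices = []
for i in range(wcs_in.low_level_wcs.pixel_n_dim):
    if i in non_reprojected_dims:
        # integer index: drop the non-reprojected dimension for this block
        slices.append(array_location[i][0])
    else:
        slices.append(slice(None))

wcs_in_sub = HighLevelWCSWrapper(SlicedLowLevelWCS(wcs_in.low_level_wcs, slices=slices))
print(wcs_in_sub.low_level_wcs.pixel_n_dim)  # 2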

reproject/interpolation/core.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@

 def _validate_wcs(wcs_in, wcs_out, shape_in, shape_out):
     if wcs_in.low_level_wcs.pixel_n_dim != wcs_out.low_level_wcs.pixel_n_dim:
-        raise ValueError("Number of dimensions in input and output WCS should match")
+        raise ValueError(f"Number of dimensions in input and output WCS should match (got {wcs_in.low_level_wcs.pixel_n_dim} and {wcs_out.low_level_wcs.pixel_n_dim})")
     elif len(shape_out) < wcs_out.low_level_wcs.pixel_n_dim:
         raise ValueError("Too few dimensions in shape_out")
     elif len(shape_in) < wcs_in.low_level_wcs.pixel_n_dim:
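
The enriched message makes dimension mismatches self-diagnosing. A standalone sketch reproducing the check outside the private helper (the WCS objects here are stand-ins):

from astropy.wcs import WCS

wcs_in, wcs_out = WCS(naxis=3), WCS(naxis=2)

if wcs_in.low_level_wcs.pixel_n_dim != wcs_out.low_level_wcs.pixel_n_dim:
    raise ValueError(
        "Number of dimensions in input and output WCS should match "
        f"(got {wcs_in.low_level_wcs.pixel_n_dim} and {wcs_out.low_level_wcs.pixel_n_dim})"
    )
# ValueError: Number of dimensions in input and output WCS should match (got 3 and 2)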

reproject/interpolation/high_level.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@ def reproject_interp(
     output_footprint=None,
     return_footprint=True,
     block_size=None,
+    non_reprojected_dims=None,
     parallel=False,
     return_type=None,
 ):
@@ -142,6 +143,7 @@ def reproject_interp(
         array_out=output_array,
         parallel=parallel,
         block_size=block_size,
+        non_reprojected_dims=non_reprojected_dims,
         return_footprint=return_footprint,
         output_footprint=output_footprint,
         reproject_func_kwargs=dict(
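
With the keyword threaded through, reproject_interp exposes it directly. A speculative usage sketch, assuming the WIP semantics above hold (the first axis of the cube is copied 1-to-1 while each celestial plane is regridded; shapes and WCS values are made up):

import numpy as np
from astropy.wcs import WCS
from reproject import reproject_interp

cube = np.random.random((5, 64, 64))  # hypothetical spectral cube
wcs_in = WCS(naxis=2)
wcs_in.wcs.ctype = ["RA---TAN", "DEC--TAN"]
wcs_out = WCS(naxis=2)
wcs_out.wcs.ctype = ["RA---TAN", "DEC--TAN"]

array, footprint = reproject_interp(
    (cube, wcs_in),
    wcs_out,
    shape_out=(5, 64, 64),
    block_size=(1, 64, 64),      # one task per plane along the leading axis
    non_reprojected_dims=(0,),   # new WIP keyword from this commit
)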
