Skip to content

Commit 5678b75

Browse files
Update map_blocks to use chunksizes property. (#6776)
* Update map_blocks to use chunksizes property. Raise nicer error if provided template has no dask arrays. Closes #6763 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typing Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f28d7f8 commit 5678b75

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

xarray/core/parallel.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -373,19 +373,19 @@ def _wrapper(
373373
new_indexes = template_indexes - set(input_indexes)
374374
indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
375375
indexes.update({k: template._indexes[k] for k in new_indexes})
376-
output_chunks = {
376+
output_chunks: Mapping[Hashable, tuple[int, ...]] = {
377377
dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
378378
}
379379

380380
else:
381381
# template xarray object has been provided with proper sizes and chunk shapes
382382
indexes = dict(template._indexes)
383-
if isinstance(template, DataArray):
384-
output_chunks = dict(
385-
zip(template.dims, template.chunks) # type: ignore[arg-type]
383+
output_chunks = template.chunksizes
384+
if not output_chunks:
385+
raise ValueError(
386+
"Provided template has no dask arrays. "
387+
" Please construct a template with appropriately chunked dask arrays."
386388
)
387-
else:
388-
output_chunks = dict(template.chunks)
389389

390390
for dim in output_chunks:
391391
if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]):

xarray/tests/test_dask.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,15 @@ def sumda(da1, da2):
12431243
)
12441244
xr.testing.assert_equal((da1 + da2).sum("x"), mapped)
12451245

1246+
# bad template: not chunked
1247+
with pytest.raises(ValueError, match="Provided template has no dask arrays"):
1248+
xr.map_blocks(
1249+
lambda a, b: (a + b).sum("x"),
1250+
da1,
1251+
args=[da2],
1252+
template=da1.sum("x").compute(),
1253+
)
1254+
12461255

12471256
@pytest.mark.parametrize("obj", [make_da(), make_ds()])
12481257
def test_map_blocks_add_attrs(obj):

0 commit comments

Comments
 (0)