From 0ef21c038c1eb79602f5d59e5fa404db7eea8824 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 12 Jul 2022 09:14:34 -0600 Subject: [PATCH 1/3] Update map_blocks to use chunksizes property. Raise nicer error if provided template has no dask arrays. Closes #6763 --- xarray/core/parallel.py | 10 +++++----- xarray/tests/test_dask.py | 6 ++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index fd1f3f9e999..9c778b2ee5a 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -380,12 +380,12 @@ def _wrapper( else: # template xarray object has been provided with proper sizes and chunk shapes indexes = dict(template._indexes) - if isinstance(template, DataArray): - output_chunks = dict( - zip(template.dims, template.chunks) # type: ignore[arg-type] + output_chunks = template.chunksizes + if not output_chunks: + raise ValueError( + "Provided template has no dask arrays. " + " Please construct a template with appropriately chunked dask arrays." ) - else: - output_chunks = dict(template.chunks) for dim in output_chunks: if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 0d6ee7e503a..82d1df6ea9b 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1243,6 +1243,12 @@ def sumda(da1, da2): ) xr.testing.assert_equal((da1 + da2).sum("x"), mapped) + # bad template: not chunked + with pytest.raises(ValueError, match="Provided template has no dask arrays"): + xr.map_blocks( + lambda a, b: (a + b).sum("x"), da1, args=[da2], template=da1.sum("x").compute() + ) + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_add_attrs(obj): From f3cf3b67083c7c6a7866dd388eece81adcf84c4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Jul 2022 15:18:25 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/tests/test_dask.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 82d1df6ea9b..51845b2159e 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1246,7 +1246,10 @@ def sumda(da1, da2): # bad template: not chunked with pytest.raises(ValueError, match="Provided template has no dask arrays"): xr.map_blocks( - lambda a, b: (a + b).sum("x"), da1, args=[da2], template=da1.sum("x").compute() + lambda a, b: (a + b).sum("x"), + da1, + args=[da2], + template=da1.sum("x").compute(), ) From 72ec730fc88fa206e6297314d0ac7b7bd1579d7a Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 12 Jul 2022 09:31:35 -0600 Subject: [PATCH 3/3] fix typing --- xarray/core/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 9c778b2ee5a..2e3aff68a26 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -373,7 +373,7 @@ def _wrapper( new_indexes = template_indexes - set(input_indexes) indexes = {dim: input_indexes[dim] for dim in preserved_indexes} indexes.update({k: template._indexes[k] for k in new_indexes}) - output_chunks = { + output_chunks: Mapping[Hashable, tuple[int, ...]] = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks }