Skip to content

Commit 12292e6

Browse files
committed
Fix map_blocks HLG layering
This fixes an issue with the HighLevelGraph noted in pydata#3584, which was exposed by a recent change in Dask that performs more HLG fusion.
1 parent 87a25b6 commit 12292e6

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

xarray/core/parallel.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
except ImportError:
88
pass
99

10+
import collections
1011
import itertools
1112
import operator
1213
from typing import (
@@ -222,6 +223,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
222223
indexes.update({k: template.indexes[k] for k in new_indexes})
223224

224225
graph: Dict[Any, Any] = {}
226+
new_layers = collections.defaultdict(dict)
225227
gname = "{}-{}".format(
226228
dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs)
227229
)
@@ -310,10 +312,14 @@ def _wrapper(func, obj, to_array, args, kwargs):
310312
# unchunked dimensions in the input have one chunk in the result
311313
key += (0,)
312314

313-
graph[key] = (operator.getitem, from_wrapper, name)
315+
new_layers[gname_l][key] = (operator.getitem, from_wrapper, name)
314316

315317
graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
316318

319+
for gname_l, layer in new_layers.items():
320+
graph.dependencies[gname_l] = {gname}
321+
graph.layers[gname_l] = layer
322+
317323
result = Dataset(coords=indexes, attrs=template.attrs)
318324
for name, gname_l in var_key_map.items():
319325
dims = template[name].dims

xarray/tests/test_dask.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,13 @@ def func(obj):
11891189
assert_identical(expected.compute(), actual.compute())
11901190

11911191

1192+
def test_map_blocks_hlg_layers():
1193+
ds = xr.Dataset({'x': (('y',), dask.array.ones(10, chunks=(5,)))})
1194+
mapped = ds.map_blocks(lambda x: x)
1195+
1196+
xr.testing.assert_equal(mapped, ds) # does not work
1197+
1198+
11921199
def test_make_meta(map_ds):
11931200
from ..core.parallel import make_meta
11941201

0 commit comments

Comments
 (0)