Commit 13c09dc

TomAugspurger, dcherian, and max-sixty authored
Fixed dask.optimize on datasets (#4438)
* Fixed dask.optimize on datasets

  Another attempt to fix #3698. The issue with my fix in is that we hit `Variable._dask_finalize` in both `dask.optimize` and `dask.persist`. We want to do the culling of unnecessary tasks (`test_persist_Dataset`) but only in the persist case, not optimize (`test_optimize`).

* Update whats-new.rst

* Update doc/whats-new.rst

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Maximilian Roos <[email protected]>
1 parent 0c26211 commit 13c09dc

4 files changed

+20
-5
lines changed


doc/whats-new.rst

Lines changed: 2 additions & 1 deletion
@@ -101,6 +101,8 @@ Bug fixes
   By `Jens Svensmark <https://github.com/jenssss>`_
 - Fix incorrect legend labels for :py:meth:`Dataset.plot.scatter` (:issue:`4126`).
   By `Peter Hausamann <https://github.com/phausamann>`_.
+- Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`)
+  By `Tom Augspurger <https://github.com/TomAugspurger>`_
 - Fix ``pip install .`` when no ``.git`` directory exists; namely when the xarray source
   directory has been rsync'ed by PyCharm Professional for a remote deployment over SSH.
   By `Guido Imperiale <https://github.com/crusaderky>`_
@@ -109,7 +111,6 @@ Bug fixes
 - Avoid relying on :py:class:`set` objects for the ordering of the coordinates (:pull:`4409`)
   By `Justus Magin <https://github.com/keewis>`_.

-
 Documentation
 ~~~~~~~~~~~~~

xarray/core/dataset.py

Lines changed: 10 additions & 1 deletion
@@ -777,10 +777,19 @@ def _dask_postcompute(results, info, *args):
     @staticmethod
     def _dask_postpersist(dsk, info, *args):
         variables = {}
+        # postpersist is called in both dask.optimize and dask.persist
+        # When persisting, we want to filter out unrelated keys for
+        # each Variable's task graph.
+        is_persist = len(dsk) == len(info)
         for is_dask, k, v in info:
             if is_dask:
                 func, args2 = v
-                result = func(dsk, *args2)
+                if is_persist:
+                    name = args2[1][0]
+                    dsk2 = {k: v for k, v in dsk.items() if k[0] == name}
+                else:
+                    dsk2 = dsk
+                result = func(dsk2, *args2)
             else:
                 result = v
             variables[k] = result
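
The key filtering added to `_dask_postpersist` can be sketched in isolation. This is a minimal illustration with plain dictionaries, not the xarray implementation: `select_graph` and the key names `a-123` and `b-456` are hypothetical, and the persist/optimize decision is passed in as a flag rather than inferred from `len(dsk) == len(info)` as the commit does.

```python
def select_graph(dsk, name, is_persist):
    """Return the mapping to hand to one variable's rebuild function.

    Hypothetical helper that mirrors the branch added in this commit.
    """
    if is_persist:
        # Persist case: keep only keys whose first element is this variable's name.
        return {k: v for k, v in dsk.items() if k[0] == name}
    # Optimize case: pass the full mapping through untouched.
    return dsk


# Hypothetical persisted results for two variables named "a-123" and "b-456".
dsk = {
    ("a-123", 0): "chunk a0",
    ("a-123", 1): "chunk a1",
    ("b-456", 0): "chunk b0",
}

only_a = select_graph(dsk, "a-123", is_persist=True)
everything = select_graph(dsk, "a-123", is_persist=False)
```

In the persist branch each variable sees only its own keys; in the optimize branch the same object is returned, so no task is dropped.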

xarray/core/variable.py

Lines changed: 0 additions & 3 deletions
@@ -501,9 +501,6 @@ def __dask_postpersist__(self):

     @staticmethod
     def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
-        if isinstance(results, dict):  # persist case
-            name = array_args[0]
-            results = {k: v for k, v in results.items() if k[0] == name}
         data = array_func(results, *array_args)
         return Variable(dims, data, attrs=attrs, encoding=encoding)
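
To see why filtering by name is only safe in the persist case, consider a toy graph in the style of Dask's dict-based task graphs. This is a simplified sketch, assuming each task is a tuple whose first element is a callable and whose tuple arguments are keys of other tasks; `missing_dependencies`, `make_ones`, `rechunk`, and the key names are hypothetical and not part of xarray or Dask.

```python
def missing_dependencies(graph):
    """Return tuple keys that tasks in ``graph`` reference but ``graph`` lacks."""
    missing = set()
    for task in graph.values():
        # Assumed task layout: (callable, arg1, arg2, ...); tuple args are keys.
        for arg in task[1:]:
            if isinstance(arg, tuple) and arg not in graph:
                missing.add(arg)
    return missing


def make_ones():
    return [1.0, 1.0]


def rechunk(block):
    return list(block)


# An unevaluated graph: the final "rechunk-xyz" layer depends on "ones-abc".
graph = {
    ("ones-abc", 0): (make_ones,),
    ("rechunk-xyz", 0): (rechunk, ("ones-abc", 0)),
}

# Applying the name filter to an unevaluated graph drops the upstream layer.
filtered = {k: v for k, v in graph.items() if k[0] == "rechunk-xyz"}

dangling = missing_dependencies(filtered)
```

After `dask.persist` the mapping holds already-computed chunks, so dropping other variables' keys loses nothing the variable needs. After `dask.optimize` the mapping is still an unevaluated graph, and in this toy example the same filter leaves `("rechunk-xyz", 0)` pointing at a key that no longer exists.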

xarray/tests/test_dask.py

Lines changed: 8 additions & 0 deletions
@@ -1607,3 +1607,11 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds):
     assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da)
     assert_equal(map_da.astype(map_da.dtype), map_da)
     assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy)
+
+
+def test_optimize():
+    # https://github.com/pydata/xarray/issues/3698
+    a = dask.array.ones((10, 4), chunks=(5, 2))
+    arr = xr.DataArray(a).chunk(5)
+    (arr2,) = dask.optimize(arr)
+    arr2.compute()
