
Commit 9a8a62b

Fix optimize for chunked DataArray (#4432)
Previously we generated an invalid Dask task graph, because the lines removed here dropped keys that were referenced elsewhere in the task graph. The original implementation had a comment indicating that this was to cull: https://github.com/pydata/xarray/blame/502a988ad5b87b9f3aeec3033bf55c71272e1053/xarray/core/variable.py#L384

Just spot-checking things, I think we're OK here though. Something like `dask.visualize(arr[[0]], optimize_graph=True)` indicates that we're OK.

Closes #3698

Co-authored-by: Maximilian Roos <[email protected]>
1 parent b0d8d93 commit 9a8a62b

3 files changed: 9 additions and 3 deletions

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
@@ -84,11 +84,13 @@ Bug fixes
 - Fix `KeyError` when doing linear interpolation to an nd `DataArray`
   that contains NaNs (:pull:`4233`).
   By `Jens Svensmark <https://github.com/jenssss>`_
+- Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`)
 - Fix incorrect legend labels for :py:meth:`Dataset.plot.scatter` (:issue:`4126`).
   By `Peter Hausamann <https://github.com/phausamann>`_.
 - Fix indexing with datetime64 scalars with pandas 1.1 (:issue:`4283`).
   By `Stephan Hoyer <https://github.com/shoyer>`_ and
   `Justus Magin <https://github.com/keewis>`_.
+
 
 Documentation
 ~~~~~~~~~~~~~

xarray/core/variable.py

Lines changed: 0 additions & 3 deletions
@@ -501,9 +501,6 @@ def __dask_postpersist__(self):
 
     @staticmethod
     def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
-        if isinstance(results, dict):  # persist case
-            name = array_args[0]
-            results = {k: v for k, v in results.items() if k[0] == name}
         data = array_func(results, *array_args)
         return Variable(dims, data, attrs=attrs, encoding=encoding)
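
The failure mode described in the commit message can be illustrated with a toy graph. The sketch below is a simplified stand-in, not Dask's or xarray's actual implementation: `get`, `filter_by_name`, `cull`, and the key names are all hypothetical. It assumes a graph is a plain dict whose keys are `(name, index)` tuples and whose tasks are `(callable, *dependency_keys)` tuples. Filtering by the first element of the key, as the removed lines did, drops upstream keys that the remaining task still references; a dependency-aware cull keeps them.

```python
# Toy task graph in the style of a Dask graph: each key maps either to a
# literal value or to a task tuple (callable, *dependency_keys).
# All names here are hypothetical and only illustrate the failure mode.


def add(x, y):
    return x + y


def get(graph, key):
    """Recursively evaluate `key`, looking up its dependencies in `graph`."""
    task = graph[key]  # raises KeyError if a dependency was dropped
    if isinstance(task, tuple) and callable(task[0]):
        func, *arg_keys = task
        return func(*(get(graph, k) for k in arg_keys))
    return task


def filter_by_name(graph, name):
    """Mimic the removed lines: keep only keys whose first element is `name`."""
    return {k: v for k, v in graph.items() if k[0] == name}


def cull(graph, keys):
    """Dependency-aware alternative: keep `keys` plus everything they need."""
    kept, stack = {}, list(keys)
    while stack:
        key = stack.pop()
        if key in kept:
            continue
        task = graph[key]
        kept[key] = task
        if isinstance(task, tuple) and callable(task[0]):
            stack.extend(task[1:])  # follow the task's dependency keys
    return kept


graph = {
    ("ones", 0): 1,
    ("ones", 1): 1,
    ("rechunk", 0): (add, ("ones", 0), ("ones", 1)),
    ("unused", 0): 99,
}

# Name-based filtering keeps ("rechunk", 0) but drops the ("ones", i) keys
# it depends on, so evaluating the filtered graph fails.
filtered = filter_by_name(graph, "rechunk")
try:
    get(filtered, ("rechunk", 0))
    naive_ok = True
except KeyError:
    naive_ok = False

# A dependency-aware cull drops only the key that nothing references.
culled = cull(graph, [("rechunk", 0)])
culled_result = get(culled, ("rechunk", 0))
```

In this sketch the name-based filter leaves a graph that cannot be evaluated, while the cull removes only `("unused", 0)`. Removing the filter entirely, as this commit does, sidesteps the problem by passing the graph through unchanged.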

xarray/tests/test_dask.py

Lines changed: 7 additions & 0 deletions
@@ -1607,3 +1607,10 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds):
     assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da)
     assert_equal(map_da.astype(map_da.dtype), map_da)
     assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy)
+
+
+def test_optimize():
+    a = dask.array.ones((10, 5), chunks=(1, 3))
+    arr = xr.DataArray(a).chunk(5)
+    (arr2,) = dask.optimize(arr)
+    arr2.compute()
