From 39bb24019dcb69f22cd13ad41e82e302db4cdb39 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 18 Sep 2020 16:26:19 -0500 Subject: [PATCH 1/3] Fixed dask.optimize on datasets Another attempt to fix #3698. The issue with my previous fix is that we hit `Variable._dask_finalize` in both `dask.optimize` and `dask.persist`. We want to do the culling of unnecessary tasks (`test_persist_Dataset`) but only in the persist case, not optimize (`test_optimize`). --- doc/whats-new.rst | 2 ++ xarray/core/dataset.py | 11 ++++++++++- xarray/core/variable.py | 3 --- xarray/tests/test_dask.py | 8 ++++++++ 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d8b1fc2fba9..8ac002e14f3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -89,6 +89,8 @@ Bug fixes - Fix indexing with datetime64 scalars with pandas 1.1 (:issue:`4283`). By `Stephan Hoyer `_ and `Justus Magin `_. + - Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`) + By `Tom Augspurger `_ Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 825d2044a12..c4f22197bd9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -777,10 +777,19 @@ def _dask_postcompute(results, info, *args): @staticmethod def _dask_postpersist(dsk, info, *args): variables = {} + # postpersist is called in both dask.optimize and dask.persist + # When persisting, we want to filter out unrelated keys for + # each Variable's task graph. 
+ is_persist = len(dsk) == len(info) for is_dask, k, v in info: if is_dask: func, args2 = v - result = func(dsk, *args2) + if is_persist: + name = args2[1][0] + dsk2 = {k: v for k, v in dsk.items() if k[0] == name} + else: + dsk2 = dsk + result = func(dsk2, *args2) else: result = v variables[k] = result diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 6de00ee882a..c55e61cb816 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -501,9 +501,6 @@ def __dask_postpersist__(self): @staticmethod def _dask_finalize(results, array_func, array_args, dims, attrs, encoding): - if isinstance(results, dict): # persist case - name = array_args[0] - results = {k: v for k, v in results.items() if k[0] == name} data = array_func(results, *array_args) return Variable(dims, data, attrs=attrs, encoding=encoding) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 46685a29a47..7d664aca3e4 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1607,3 +1607,11 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds): assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da) assert_equal(map_da.astype(map_da.dtype), map_da) assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy) + + +def test_optimize(): + # https://github.com/pydata/xarray/issues/3698 + a = dask.array.ones((10, 4), chunks=(5, 2)) + arr = xr.DataArray(a).chunk(5) + (arr2,) = dask.optimize(arr) + arr2.compute() From 42a52f3c524b916851a6b1694b8e3ff2f5638863 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 20 Sep 2020 00:09:11 +0000 Subject: [PATCH 2/3] Update whats-new.rst --- doc/whats-new.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f036ef0809f..e1053bf6ec3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -101,9 +101,6 @@ Bug fixes By `Jens Svensmark `_ - Fix incorrect legend labels for :py:meth:`Dataset.plot.scatter` 
(:issue:`4126`). By `Peter Hausamann `_. -- Fix indexing with datetime64 scalars with pandas 1.1 (:issue:`4283`). - By `Stephan Hoyer `_ and - `Justus Magin `_. - Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`) By `Tom Augspurger `_ - Fix ``pip install .`` when no ``.git`` directory exists; namely when the xarray source From 8c501dfb560635465f769ccef79cfff9db1fd9d7 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 19 Sep 2020 19:27:41 -0700 Subject: [PATCH 3/3] Update doc/whats-new.rst --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e1053bf6ec3..82f51a1beec 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -101,7 +101,7 @@ Bug fixes By `Jens Svensmark `_ - Fix incorrect legend labels for :py:meth:`Dataset.plot.scatter` (:issue:`4126`). By `Peter Hausamann `_. - - Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`) +- Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`) By `Tom Augspurger `_ - Fix ``pip install .`` when no ``.git`` directory exists; namely when the xarray source directory has been rsync'ed by PyCharm Professional for a remote deployment over SSH.