From 12fe06979294dfd754fa5414d79334a5cae0387f Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Tue, 31 Mar 2020 15:16:52 +0200 Subject: [PATCH 01/20] FIX: correct dask array handling in _calc_idxminmax --- xarray/core/computation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 6cf4178b5bf..65e8a09a667 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1384,8 +1384,9 @@ def _calc_idxminmax( coordarray = array[dim] # Handle dask arrays. - if isinstance(array, dask_array_type): - res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype) + if isinstance(array.data, dask_array_type): + res = array.map_blocks(lambda a, b: a[b], coordarray, indx, + dtype=indx.dtype).compute() else: res = coordarray[ indx, From 59a74d69c942126f11d847622b961f03a6145cb8 Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Tue, 31 Mar 2020 15:31:42 +0200 Subject: [PATCH 02/20] FIX: remove unneeded import, reformat via black --- xarray/core/computation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 65e8a09a667..92b35339f7b 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -26,7 +26,6 @@ from . import dtypes, duck_array_ops, utils from .alignment import deep_align from .merge import merge_coordinates_without_align -from .nanops import dask_array from .options import OPTIONS from .pycompat import dask_array_type from .utils import is_dict_like @@ -1385,8 +1384,9 @@ def _calc_idxminmax( # Handle dask arrays. if isinstance(array.data, dask_array_type): - res = array.map_blocks(lambda a, b: a[b], coordarray, indx, - dtype=indx.dtype).compute() + res = array.map_blocks( + lambda a, b: a[b], coordarray, indx, dtype=indx.dtype + ).compute() else: res = coordarray[ indx, From 7369007fd364ef6a37ff0e3f00011f6f8f45e2ea Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 31 Mar 2020 07:59:32 -0600 Subject: [PATCH 03/20] fix idxmax, idxmin with dask arrays --- xarray/core/computation.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 92b35339f7b..e0c297bca0f 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1384,21 +1384,20 @@ def _calc_idxminmax( # Handle dask arrays. if isinstance(array.data, dask_array_type): - res = array.map_blocks( - lambda a, b: a[b], coordarray, indx, dtype=indx.dtype - ).compute() + res = indx.copy( + data=indx.data.map_blocks( + lambda ind, coord: coord[(ind,)], coordarray, dtype=coordarray.dtype + ) + ) else: - res = coordarray[ - indx, - ] + res = coordarray[(indx,)] + # The dim is gone but we need to remove the corresponding coordinate. + del res.coords[dim] if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them res = res.where(~allna, fill_value) - # The dim is gone but we need to remove the corresponding coordinate. - del res.coords[dim] - # Copy attributes from argmin/argmax, if any res.attrs = indx.attrs From 9a862fc0ff5d61ce1f528b236246d33e7f7292f1 Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Fri, 3 Apr 2020 12:57:17 +0200 Subject: [PATCH 04/20] FIX: use array[dim].data in `_calc_idxminmax` as per @keewis suggestion, attach dim name to result --- xarray/core/computation.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index e0c297bca0f..bf709d7e4ad 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1379,18 +1379,19 @@ def _calc_idxminmax( # This will run argmin or argmax. indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna) - # Get the coordinate we want. - coordarray = array[dim] - # Handle dask arrays. if isinstance(array.data, dask_array_type): res = indx.copy( data=indx.data.map_blocks( - lambda ind, coord: coord[(ind,)], coordarray, dtype=coordarray.dtype + lambda ind, coord: coord[(ind,)], + array[dim].data, + dtype=array[dim].dtype, ) ) + # we need to attach back the dim name + res.name = dim else: - res = coordarray[(indx,)] + res = array[dim][(indx,)] # The dim is gone but we need to remove the corresponding coordinate. del res.coords[dim] From b154c8aca8e0f6a22ef448cc0b348ef37ff664b8 Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Fri, 3 Apr 2020 12:58:13 +0200 Subject: [PATCH 05/20] ADD: add dask tests to `idxmin`/`idxmax` dataarray tests --- xarray/tests/test_dataarray.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 5b3e122bf72..4bea4272eb1 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4193,7 +4193,6 @@ def test_rank(self): assert_equal(y.rank("z", pct=True), y) @pytest.mark.parametrize("use_dask", [True, False]) - @pytest.mark.parametrize("use_datetime", [True, False]) def test_polyfit(self, use_dask, use_datetime): if use_dask and not has_dask: pytest.skip("requires dask") @@ -4490,11 +4489,21 @@ def test_argmax(self, x, minindex, maxindex, nanindex): assert_identical(result2, expected2) - def test_idxmin(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + ar0_raw = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs, ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + + print(ar0) + # dim doesn't exist with pytest.raises(KeyError): ar0.idxmin(dim="spam") @@ -4586,11 +4595,19 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): result7 = ar0.idxmin(fill_value=-1j) assert_identical(result7, expected7) - def test_idxmax(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + ar0_raw = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs, ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + # dim doesn't exist with pytest.raises(KeyError): ar0.idxmax(dim="spam") From bf5530b4995584d51f121861d055ce17a97c4906 Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Fri, 3 Apr 2020 13:06:29 +0200 Subject: [PATCH 06/20] FIX: add back fixture line removed by accident --- xarray/tests/test_dataarray.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 4bea4272eb1..a8dc3805338 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4193,6 +4193,7 @@ def test_rank(self): assert_equal(y.rank("z", pct=True), y) @pytest.mark.parametrize("use_dask", [True, False]) + @pytest.mark.parametrize("use_datetime", [True, False]) def test_polyfit(self, use_dask, use_datetime): if use_dask and not has_dask: pytest.skip("requires dask") From a60ae8950564366eac434bbe8e90db8c65b48adc Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Fri, 3 Apr 2020 14:48:58 +0200 Subject: [PATCH 07/20] ADD: complete dask handling in `idxmin`/`idxmax` tests in test_dataarray, xfail dask tests for dtype dateime64 (M) --- xarray/tests/test_dataarray.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index a8dc3805338..cd69e1bbea0 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4494,6 +4494,8 @@ def test_argmax(self, x, minindex, maxindex, nanindex): def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): if use_dask and not has_dask: pytest.skip("requires dask") + if use_dask & (x.dtype.kind == "M"): + pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") ar0_raw = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs, ) @@ -4503,8 +4505,6 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): else: ar0 = ar0_raw - print(ar0) - # dim doesn't exist with pytest.raises(KeyError): ar0.idxmin(dim="spam") @@ -4600,6 +4600,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): if use_dask and not has_dask: pytest.skip("requires dask") + if use_dask & (x.dtype.kind == "M"): + pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") ar0_raw = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs, ) @@ -4928,14 +4930,24 @@ def test_argmax(self, x, minindex, maxindex, nanindex): assert_identical(result3, expected2) - def test_idxmin(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask & (x.dtype.kind == "M"): + pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") + ar0_raw = xr.DataArray( x, dims=["y", "x"], coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, attrs=self.attrs, ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + assert_identical(ar0, ar0) # No dimension specified @@ -5035,14 +5047,24 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): result7 = ar0.idxmin(dim="x", fill_value=-5j) assert_identical(result7, expected7) - def test_idxmax(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask & (x.dtype.kind == "M"): + pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") + ar0_raw = xr.DataArray( x, dims=["y", "x"], coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, attrs=self.attrs, ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + # No dimension specified with pytest.raises(ValueError): ar0.idxmax() From e7c42c75767bb65a10950d51b1124812488963c3 Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Fri, 3 Apr 2020 15:48:58 +0200 Subject: [PATCH 08/20] ADD: add "support dask handling for idxmin/idxmax" in whats-new.rst --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c70dfd4f3f6..173dd07c3d4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,6 +41,9 @@ New Features - Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`) By `Todd Jennings `_ +- Support dask handling for :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, + :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:pull:`3922`) + By `Kai Mühlbauer `_. Bug fixes From 5c233271d44c3c0a31bf3e836f83fc6f3e7a6eb9 Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Tue, 14 Apr 2020 07:18:29 +0200 Subject: [PATCH 09/20] MIN: reintroduce changes added by #3953 --- xarray/tests/test_dataarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 0e828226558..c253d78c6ef 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4537,7 +4537,7 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): if use_dask & (x.dtype.kind == "M"): pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") ar0_raw = xr.DataArray( - x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs, + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs ) if use_dask: @@ -4643,7 +4643,7 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): if use_dask & (x.dtype.kind == "M"): pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") ar0_raw = xr.DataArray( - x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs, + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs ) if use_dask: From 081ff6f5651ca7618bbc1665e1f0b6999da3e6bc Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Tue, 14 Apr 2020 08:06:55 +0200 Subject: [PATCH 10/20] MIN: change if-clause to use `and` instead of `&` as per review-comment --- xarray/tests/test_dataarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index c253d78c6ef..781a23745b7 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4974,7 +4974,7 @@ def test_argmax(self, x, minindex, maxindex, nanindex): def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): if use_dask and not has_dask: pytest.skip("requires dask") - if use_dask & (x.dtype.kind == "M"): + if use_dask and x.dtype.kind == "M": pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") ar0_raw = xr.DataArray( x, @@ -5091,7 +5091,7 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): if use_dask and not has_dask: pytest.skip("requires dask") - if use_dask & (x.dtype.kind == "M"): + if use_dask and x.dtype.kind == "M": pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") ar0_raw = xr.DataArray( x, From 26cc8f3d8580a768f13ffda678d482476d3dd2b8 Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Tue, 14 Apr 2020 08:22:47 +0200 Subject: [PATCH 11/20] MIN: change if-clause to use `and` instead of `&` as per review-comment --- xarray/tests/test_dataarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 781a23745b7..88c39fd06d6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4534,7 +4534,7 @@ def test_argmax(self, x, minindex, maxindex, nanindex): def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): if use_dask and not has_dask: pytest.skip("requires dask") - if use_dask & (x.dtype.kind == "M"): + if use_dask and x.dtype.kind == "M": pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") ar0_raw = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs @@ -4640,7 +4640,7 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): if use_dask and not has_dask: pytest.skip("requires dask") - if use_dask & (x.dtype.kind == "M"): + if use_dask and x.dtype.kind == "M": pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") ar0_raw = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs From 613f4e61ba65cfee03fbada7135bf1e8a9420b49 Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Tue, 14 Apr 2020 09:18:47 +0200 Subject: [PATCH 12/20] WIP: remove dask handling entirely for debugging purposes --- xarray/core/computation.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index bf709d7e4ad..37ee8feaf24 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1379,21 +1379,9 @@ def _calc_idxminmax( # This will run argmin or argmax. indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna) - # Handle dask arrays. - if isinstance(array.data, dask_array_type): - res = indx.copy( - data=indx.data.map_blocks( - lambda ind, coord: coord[(ind,)], - array[dim].data, - dtype=array[dim].dtype, - ) - ) - # we need to attach back the dim name - res.name = dim - else: - res = array[dim][(indx,)] - # The dim is gone but we need to remove the corresponding coordinate. - del res.coords[dim] + res = array[dim][(indx,)] + # The dim is gone but we need to remove the corresponding coordinate. + del res.coords[dim] if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them From 8a9549222500504b52596a484d36c947deeaf6da Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 15 Apr 2020 07:07:36 -0600 Subject: [PATCH 13/20] Test for dask computes --- xarray/tests/test_dataarray.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 88c39fd06d6..4952f9983fc 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -34,6 +34,8 @@ source_ndarray, ) +from .test_dask import raise_if_dask_computes + class TestDataArray: @pytest.fixture(autouse=True) @@ -5136,15 +5138,18 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): expected0.name = "x" # Default fill value (NaN) - result0 = ar0.idxmax(dim="x") + with raise_if_dask_computes(): + result0 = ar0.idxmax(dim="x") assert_identical(result0, expected0) # Manually specify NaN fill_value - result1 = ar0.idxmax(dim="x", fill_value=np.NaN) + with raise_if_dask_computes(): + result1 = ar0.idxmax(dim="x", fill_value=np.NaN) assert_identical(result1, expected0) # keep_attrs - result2 = ar0.idxmax(dim="x", keep_attrs=True) + with raise_if_dask_computes(): + result2 = ar0.idxmax(dim="x", keep_attrs=True) expected2 = expected0.copy() expected2.attrs = self.attrs assert_identical(result2, expected2) @@ -5162,11 +5167,13 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): expected3.name = "x" expected3.attrs = {} - result3 = ar0.idxmax(dim="x", skipna=False) + with raise_if_dask_computes(): + result3 = ar0.idxmax(dim="x", skipna=False) assert_identical(result3, expected3) # fill_value should be ignored with skipna=False - result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j) + with raise_if_dask_computes(): + result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j) assert_identical(result4, expected3) # Float fill_value @@ -5178,7 +5185,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): expected5 = xr.concat(expected5, dim="y") expected5.name = "x" - result5 = ar0.idxmax(dim="x", fill_value=-1.1) + with raise_if_dask_computes(): + result5 = ar0.idxmax(dim="x", fill_value=-1.1) assert_identical(result5, expected5) # Integer fill_value @@ -5190,7 +5198,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): expected6 = xr.concat(expected6, dim="y") expected6.name = "x" - result6 = ar0.idxmax(dim="x", fill_value=-1) + with raise_if_dask_computes(): + result6 = ar0.idxmax(dim="x", fill_value=-1) assert_identical(result6, expected6) # Complex fill_value @@ -5202,7 +5211,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): expected7 = xr.concat(expected7, dim="y") expected7.name = "x" - result7 = ar0.idxmax(dim="x", fill_value=-5j) + with raise_if_dask_computes(): + result7 = ar0.idxmax(dim="x", fill_value=-5j) assert_identical(result7, expected7) From 8cf4f3e666000dc28671e1b6f53566389c3625d4 Mon Sep 17 00:00:00 2001 From: Kai Muehlbauer Date: Thu, 16 Apr 2020 08:39:01 +0200 Subject: [PATCH 14/20] WIP: re-add dask handling (map_blocks-approach), add `with raise_if_dask_computes()` context to idxmin-tests --- xarray/core/computation.py | 18 +++++++++++++++--- xarray/tests/test_dataarray.py | 26 +++++++++++++++++--------- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 37ee8feaf24..bf709d7e4ad 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1379,9 +1379,21 @@ def _calc_idxminmax( # This will run argmin or argmax. indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna) - res = array[dim][(indx,)] - # The dim is gone but we need to remove the corresponding coordinate. - del res.coords[dim] + # Handle dask arrays. + if isinstance(array.data, dask_array_type): + res = indx.copy( + data=indx.data.map_blocks( + lambda ind, coord: coord[(ind,)], + array[dim].data, + dtype=array[dim].dtype, + ) + ) + # we need to attach back the dim name + res.name = dim + else: + res = array[dim][(indx,)] + # The dim is gone but we need to remove the corresponding coordinate. + del res.coords[dim] if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 4952f9983fc..1a30f6644ff 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -5020,15 +5020,18 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): expected0.name = "x" # Default fill value (NaN) - result0 = ar0.idxmin(dim="x") + with raise_if_dask_computes(): + result0 = ar0.idxmin(dim="x") assert_identical(result0, expected0) # Manually specify NaN fill_value - result1 = ar0.idxmin(dim="x", fill_value=np.NaN) + with raise_if_dask_computes(): + result1 = ar0.idxmin(dim="x", fill_value=np.NaN) assert_identical(result1, expected0) # keep_attrs - result2 = ar0.idxmin(dim="x", keep_attrs=True) + with raise_if_dask_computes(): + result2 = ar0.idxmin(dim="x", keep_attrs=True) expected2 = expected0.copy() expected2.attrs = self.attrs assert_identical(result2, expected2) @@ -5046,11 +5049,13 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): expected3.name = "x" expected3.attrs = {} - result3 = ar0.idxmin(dim="x", skipna=False) + with raise_if_dask_computes(): + result3 = ar0.idxmin(dim="x", skipna=False) assert_identical(result3, expected3) # fill_value should be ignored with skipna=False - result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j) + with raise_if_dask_computes(): + result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j) assert_identical(result4, expected3) # Float fill_value @@ -5062,7 +5067,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): expected5 = xr.concat(expected5, dim="y") expected5.name = "x" - result5 = ar0.idxmin(dim="x", fill_value=-1.1) + with raise_if_dask_computes(): + result5 = ar0.idxmin(dim="x", fill_value=-1.1) assert_identical(result5, expected5) # Integer fill_value @@ -5074,7 +5080,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): expected6 = xr.concat(expected6, dim="y") expected6.name = "x" - result6 = ar0.idxmin(dim="x", fill_value=-1) + with raise_if_dask_computes(): + result6 = ar0.idxmin(dim="x", fill_value=-1) assert_identical(result6, expected6) # Complex fill_value @@ -5086,7 +5093,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): expected7 = xr.concat(expected7, dim="y") expected7.name = "x" - result7 = ar0.idxmin(dim="x", fill_value=-5j) + with raise_if_dask_computes(): + result7 = ar0.idxmin(dim="x", fill_value=-5j) assert_identical(result7, expected7) @pytest.mark.parametrize("use_dask", [True, False]) @@ -5172,7 +5180,7 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): assert_identical(result3, expected3) # fill_value should be ignored with skipna=False - with raise_if_dask_computes(): + with raise_if_dask_computes(1): result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j) assert_identical(result4, expected3) From 4417a35f92c50decac9e9a7064c35a9d8460aacb Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 18 Apr 2020 13:29:21 -0600 Subject: [PATCH 15/20] Use dask indexing instead of map_blocks. --- xarray/core/computation.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index bf709d7e4ad..aabdb067a88 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1381,13 +1381,10 @@ def _calc_idxminmax( # Handle dask arrays. if isinstance(array.data, dask_array_type): - res = indx.copy( - data=indx.data.map_blocks( - lambda ind, coord: coord[(ind,)], - array[dim].data, - dtype=array[dim].dtype, - ) - ) + import dask.array + + dask_coord = dask.array.from_array(array[dim].data, chunks=(1,)) + res = indx.copy(data=dask_coord[(indx.data,)]) # we need to attach back the dim name res.name = dim else: From 43384c3662a50e1520f6d6caf0027250242c8040 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 18 Apr 2020 13:40:09 -0600 Subject: [PATCH 16/20] Better chunk choice. --- xarray/core/computation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index aabdb067a88..1937b31015e 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1383,7 +1383,8 @@ def _calc_idxminmax( if isinstance(array.data, dask_array_type): import dask.array - dask_coord = dask.array.from_array(array[dim].data, chunks=(1,)) + chunks = dict(zip(array.dims, array.chunks)) + dask_coord = dask.array.from_array(array[dim].data, chunks=chunks[dim]) res = indx.copy(data=dask_coord[(indx.data,)]) # we need to attach back the dim name res.name = dim From 58901b9da821a04f2ec085577cb916c4d67f6f50 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 9 May 2020 07:33:11 -0600 Subject: [PATCH 17/20] Return -1 for _nan_argminmax_object if all NaNs along dim --- xarray/core/nanops.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index f9989c2c8c9..b373216bc99 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -47,17 +47,12 @@ def _maybe_null_out(result, axis, mask, min_count=1): def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): """ In house nanargmin, nanargmax for object arrays. Always return integer - type + type. Returns -1 for all NaN values along axis. """ valid_count = count(value, axis=axis) value = fillna(value, fill_value) data = _dask_or_eager_func(func)(value, axis=axis, **kwargs) - - # TODO This will evaluate dask arrays and might be costly. - if (valid_count == 0).any(): - raise ValueError("All-NaN slice encountered") - - return data + return where_method(data, valid_count > 0, -1) def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs): From cfc98da7b2f5e3a87c3cf2c80a36b17072d20bcb Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 9 May 2020 07:51:56 -0600 Subject: [PATCH 18/20] Revert "Return -1 for _nan_argminmax_object if all NaNs along dim" This reverts commit 58901b9da821a04f2ec085577cb916c4d67f6f50. --- xarray/core/nanops.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index b373216bc99..f9989c2c8c9 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -47,12 +47,17 @@ def _maybe_null_out(result, axis, mask, min_count=1): def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): """ In house nanargmin, nanargmax for object arrays. Always return integer - type. Returns -1 for all NaN values along axis. + type """ valid_count = count(value, axis=axis) value = fillna(value, fill_value) data = _dask_or_eager_func(func)(value, axis=axis, **kwargs) - return where_method(data, valid_count > 0, -1) + + # TODO This will evaluate dask arrays and might be costly. + if (valid_count == 0).any(): + raise ValueError("All-NaN slice encountered") + + return data def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs): From c1324bfafb5dde650cb2e886d1882c8e30ea3f01 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 9 May 2020 07:53:28 -0600 Subject: [PATCH 19/20] Raise error for object arrays --- xarray/core/computation.py | 3 +++ xarray/tests/test_dataarray.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 1937b31015e..2f6da370ef9 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1371,6 +1371,9 @@ def _calc_idxminmax( # These are dtypes with NaN values argmin and argmax can handle na_dtypes = "cfO" + if array.dtype.kind == "O": + raise ValueError("idxmin, idxmax do not support object arrays yet.") + if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Need to skip NaN values since argmin and argmax can't handle them allna = array.isnull().all(dim) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 1a30f6644ff..555d454e85d 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -5103,6 +5103,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): pytest.skip("requires dask") if use_dask and x.dtype.kind == "M": pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") + if x.dtype.kind == "O": + pytest.xfail("idxmax, idxmin not implemented for object arrays yet.") ar0_raw = xr.DataArray( x, dims=["y", "x"], From 525118bea981dd5e78516e03e693d5d771ef4a8f Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 9 May 2020 08:02:42 -0600 Subject: [PATCH 20/20] No error for object arrays. Instead expect 1 compute in tests. --- xarray/core/computation.py | 3 --- xarray/tests/test_dataarray.py | 46 +++++++++++++++++++++------------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 2f6da370ef9..1937b31015e 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1371,9 +1371,6 @@ def _calc_idxminmax( # These are dtypes with NaN values argmin and argmax can handle na_dtypes = "cfO" - if array.dtype.kind == "O": - raise ValueError("idxmin, idxmax do not support object arrays yet.") - if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Need to skip NaN values since argmin and argmax can't handle them allna = array.isnull().all(dim) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 555d454e85d..6eeaed66f9f 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4978,6 +4978,13 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): pytest.skip("requires dask") if use_dask and x.dtype.kind == "M": pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") + + if x.dtype.kind == "O": + # TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices. + max_computes = 1 + else: + max_computes = 0 + ar0_raw = xr.DataArray( x, dims=["y", "x"], @@ -5020,17 +5027,17 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): expected0.name = "x" # Default fill value (NaN) - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result0 = ar0.idxmin(dim="x") assert_identical(result0, expected0) # Manually specify NaN fill_value - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result1 = ar0.idxmin(dim="x", fill_value=np.NaN) assert_identical(result1, expected0) # keep_attrs - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result2 = ar0.idxmin(dim="x", keep_attrs=True) expected2 = expected0.copy() expected2.attrs = self.attrs @@ -5049,12 +5056,12 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): expected3.name = "x" expected3.attrs = {} - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result3 = ar0.idxmin(dim="x", skipna=False) assert_identical(result3, expected3) # fill_value should be ignored with skipna=False - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j) assert_identical(result4, expected3) @@ -5067,7 +5074,7 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): expected5 = xr.concat(expected5, dim="y") expected5.name = "x" - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result5 = ar0.idxmin(dim="x", fill_value=-1.1) assert_identical(result5, expected5) @@ -5080,7 +5087,7 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): expected6 = xr.concat(expected6, dim="y") expected6.name = "x" - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result6 = ar0.idxmin(dim="x", fill_value=-1) assert_identical(result6, expected6) @@ -5093,7 +5100,7 @@ def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): expected7 = xr.concat(expected7, dim="y") expected7.name = "x" - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result7 = ar0.idxmin(dim="x", fill_value=-5j) assert_identical(result7, expected7) @@ -5103,8 +5110,13 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): pytest.skip("requires dask") if use_dask and x.dtype.kind == "M": pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") + if x.dtype.kind == "O": - pytest.xfail("idxmax, idxmin not implemented for object arrays yet.") + # TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices. + max_computes = 1 + else: + max_computes = 0 + ar0_raw = xr.DataArray( x, dims=["y", "x"], @@ -5148,17 +5160,17 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): expected0.name = "x" # Default fill value (NaN) - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result0 = ar0.idxmax(dim="x") assert_identical(result0, expected0) # Manually specify NaN fill_value - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result1 = ar0.idxmax(dim="x", fill_value=np.NaN) assert_identical(result1, expected0) # keep_attrs - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result2 = ar0.idxmax(dim="x", keep_attrs=True) expected2 = expected0.copy() expected2.attrs = self.attrs @@ -5177,12 +5189,12 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): expected3.name = "x" expected3.attrs = {} - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result3 = ar0.idxmax(dim="x", skipna=False) assert_identical(result3, expected3) # fill_value should be ignored with skipna=False - with raise_if_dask_computes(1): + with raise_if_dask_computes(max_computes=max_computes): result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j) assert_identical(result4, expected3) @@ -5195,7 +5207,7 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): expected5 = xr.concat(expected5, dim="y") expected5.name = "x" - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result5 = ar0.idxmax(dim="x", fill_value=-1.1) assert_identical(result5, expected5) @@ -5208,7 +5220,7 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): expected6 = xr.concat(expected6, dim="y") expected6.name = "x" - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result6 = ar0.idxmax(dim="x", fill_value=-1) assert_identical(result6, expected6) @@ -5221,7 +5233,7 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): expected7 = xr.concat(expected7, dim="y") expected7.name = "x" - with raise_if_dask_computes(): + with raise_if_dask_computes(max_computes=max_computes): result7 = ar0.idxmax(dim="x", fill_value=-5j) assert_identical(result7, expected7)