diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a1d52b28ed5..4fbeb7033c7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -53,6 +53,9 @@ New Features - Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`) By `Todd Jennings `_ +- Support dask handling for :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, + :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:pull:`3922`) + By `Kai Mühlbauer `_. - More support for unit aware arrays with pint (:pull:`3643`) By `Justus Magin `_. - Support overriding existing variables in ``to_zarr()`` with ``mode='a'`` even diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a3723ea9db9..28bf818e4a3 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -26,7 +26,6 @@ from . import dtypes, duck_array_ops, utils from .alignment import deep_align from .merge import merge_coordinates_without_align -from .nanops import dask_array from .options import OPTIONS from .pycompat import dask_array_type from .utils import is_dict_like @@ -1380,24 +1379,24 @@ def _calc_idxminmax( # This will run argmin or argmax. indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna) - # Get the coordinate we want. - coordarray = array[dim] - # Handle dask arrays. - if isinstance(array, dask_array_type): - res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype) + if isinstance(array.data, dask_array_type): + import dask.array + + chunks = dict(zip(array.dims, array.chunks)) + dask_coord = dask.array.from_array(array[dim].data, chunks=chunks[dim]) + res = indx.copy(data=dask_coord[(indx.data,)]) + # we need to attach back the dim name + res.name = dim else: - res = coordarray[ - indx, - ] + res = array[dim][(indx,)] + # The dim is gone but we need to remove the corresponding coordinate. + del res.coords[dim] if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them res = res.where(~allna, fill_value) - # The dim is gone but we need to remove the corresponding coordinate. - del res.coords[dim] - # Copy attributes from argmin/argmax, if any res.attrs = indx.attrs diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 6984d5361d2..a01234616a4 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -34,6 +34,8 @@ source_ndarray, ) +from .test_dask import raise_if_dask_computes + class TestDataArray: @pytest.fixture(autouse=True) @@ -4524,11 +4526,21 @@ def test_argmax(self, x, minindex, maxindex, nanindex): assert_identical(result2, expected2) - def test_idxmin(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") + ar0_raw = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + # dim doesn't exist with pytest.raises(KeyError): ar0.idxmin(dim="spam") @@ -4620,11 +4632,21 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): result7 = ar0.idxmin(fill_value=-1j) assert_identical(result7, expected7) - def test_idxmax(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") + ar0_raw = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + # dim doesn't exist with pytest.raises(KeyError): ar0.idxmax(dim="spam") @@ -4944,14 +4966,31 @@ def test_argmax(self, x, minindex, maxindex, nanindex): assert_identical(result3, expected2) - def test_idxmin(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") + + if x.dtype.kind == "O": + # TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices. + max_computes = 1 + else: + max_computes = 0 + + ar0_raw = xr.DataArray( x, dims=["y", "x"], coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, attrs=self.attrs, ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + assert_identical(ar0, ar0) # No dimension specified @@ -4982,15 +5021,18 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): expected0.name = "x" # Default fill value (NaN) - result0 = ar0.idxmin(dim="x") + with raise_if_dask_computes(max_computes=max_computes): + result0 = ar0.idxmin(dim="x") assert_identical(result0, expected0) # Manually specify NaN fill_value - result1 = ar0.idxmin(dim="x", fill_value=np.NaN) + with raise_if_dask_computes(max_computes=max_computes): + result1 = ar0.idxmin(dim="x", fill_value=np.NaN) assert_identical(result1, expected0) # keep_attrs - result2 = ar0.idxmin(dim="x", keep_attrs=True) + with raise_if_dask_computes(max_computes=max_computes): + result2 = ar0.idxmin(dim="x", keep_attrs=True) expected2 = expected0.copy() expected2.attrs = self.attrs assert_identical(result2, expected2) @@ -5008,11 +5050,13 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): expected3.name = "x" expected3.attrs = {} - result3 = ar0.idxmin(dim="x", skipna=False) + with raise_if_dask_computes(max_computes=max_computes): + result3 = ar0.idxmin(dim="x", skipna=False) assert_identical(result3, expected3) # fill_value should be ignored with skipna=False - result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j) + with raise_if_dask_computes(max_computes=max_computes): + result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j) assert_identical(result4, expected3) # Float fill_value @@ -5024,7 +5068,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): expected5 = xr.concat(expected5, dim="y") expected5.name = "x" - result5 = ar0.idxmin(dim="x", fill_value=-1.1) + with raise_if_dask_computes(max_computes=max_computes): + result5 = ar0.idxmin(dim="x", fill_value=-1.1) assert_identical(result5, expected5) # Integer fill_value @@ -5036,7 +5081,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): expected6 = xr.concat(expected6, dim="y") expected6.name = "x" - result6 = ar0.idxmin(dim="x", fill_value=-1) + with raise_if_dask_computes(max_computes=max_computes): + result6 = ar0.idxmin(dim="x", fill_value=-1) assert_identical(result6, expected6) # Complex fill_value @@ -5048,17 +5094,35 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): expected7 = xr.concat(expected7, dim="y") expected7.name = "x" - result7 = ar0.idxmin(dim="x", fill_value=-5j) + with raise_if_dask_computes(max_computes=max_computes): + result7 = ar0.idxmin(dim="x", fill_value=-5j) assert_identical(result7, expected7) - def test_idxmax(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") + + if x.dtype.kind == "O": + # TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices. + max_computes = 1 + else: + max_computes = 0 + + ar0_raw = xr.DataArray( x, dims=["y", "x"], coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, attrs=self.attrs, ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + # No dimension specified with pytest.raises(ValueError): ar0.idxmax() @@ -5090,15 +5154,18 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): expected0.name = "x" # Default fill value (NaN) - result0 = ar0.idxmax(dim="x") + with raise_if_dask_computes(max_computes=max_computes): + result0 = ar0.idxmax(dim="x") assert_identical(result0, expected0) # Manually specify NaN fill_value - result1 = ar0.idxmax(dim="x", fill_value=np.NaN) + with raise_if_dask_computes(max_computes=max_computes): + result1 = ar0.idxmax(dim="x", fill_value=np.NaN) assert_identical(result1, expected0) # keep_attrs - result2 = ar0.idxmax(dim="x", keep_attrs=True) + with raise_if_dask_computes(max_computes=max_computes): + result2 = ar0.idxmax(dim="x", keep_attrs=True) expected2 = expected0.copy() expected2.attrs = self.attrs assert_identical(result2, expected2) @@ -5116,11 +5183,13 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): expected3.name = "x" expected3.attrs = {} - result3 = ar0.idxmax(dim="x", skipna=False) + with raise_if_dask_computes(max_computes=max_computes): + result3 = ar0.idxmax(dim="x", skipna=False) assert_identical(result3, expected3) # fill_value should be ignored with skipna=False - result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j) + with raise_if_dask_computes(max_computes=max_computes): + result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j) assert_identical(result4, expected3) # Float fill_value @@ -5132,7 +5201,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): expected5 = xr.concat(expected5, dim="y") expected5.name = "x" - result5 = ar0.idxmax(dim="x", fill_value=-1.1) + with raise_if_dask_computes(max_computes=max_computes): + result5 = ar0.idxmax(dim="x", fill_value=-1.1) assert_identical(result5, expected5) # Integer fill_value @@ -5144,7 +5214,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): expected6 = xr.concat(expected6, dim="y") expected6.name = "x" - result6 = ar0.idxmax(dim="x", fill_value=-1) + with raise_if_dask_computes(max_computes=max_computes): + result6 = ar0.idxmax(dim="x", fill_value=-1) assert_identical(result6, expected6) # Complex fill_value @@ -5156,7 +5227,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): expected7 = xr.concat(expected7, dim="y") expected7.name = "x" - result7 = ar0.idxmax(dim="x", fill_value=-5j) + with raise_if_dask_computes(max_computes=max_computes): + result7 = ar0.idxmax(dim="x", fill_value=-5j) assert_identical(result7, expected7)