From 55dd4ebff21b1002bfacf91ab9c26a7c1903397e Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 5 Dec 2019 16:41:59 -0700 Subject: [PATCH 1/6] Add nanmedian for dask arrays Close #2999 --- doc/whats-new.rst | 3 ++ xarray/core/dask_array_compat.py | 74 ++++++++++++++++++++++++++++++++ xarray/core/duck_array_ops.py | 8 ++-- xarray/core/nanops.py | 5 ++- xarray/tests/test_dask.py | 2 +- 5 files changed, 86 insertions(+), 6 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 554f0bc4695..47e89e2be85 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,9 @@ Breaking changes New Features ~~~~~~~~~~~~ +- Implement :py:func:`median` and :py:func:`nanmedian` for dask arrays. This works by rechunking + to a single chunk along all reduction axes. (:issue:`2999`). + By `Deepak Cherian <https://github.com/dcherian>`_. - :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` and ``GroupBy.quantile`` now work with dask Variables. By `Deepak Cherian <https://github.com/dcherian>`_. diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index c3dbdd27098..a21217958e5 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -1,4 +1,5 @@ from distutils.version import LooseVersion +from typing import Iterable import dask.array as da import numpy as np @@ -89,3 +90,76 @@ def meta_from_array(x, ndim=None, dtype=None): meta = meta.astype(dtype) return meta + + +if LooseVersion(dask_version) >= LooseVersion("2.8.1"): + median = da.median +else: + # Copied from dask v2.8.1 + # Used under the terms of Dask's license, see licenses/DASK_LICENSE. + def median(a, axis=None, keepdims=False): + """ + This works by automatically chunking the reduced axes to a single chunk + and then calling ``numpy.median`` function across the remaining dimensions + """ + + if axis is None: + raise NotImplementedError( + "The da.median function only works along an axis. 
" + "The full algorithm is difficult to do in parallel" + ) + + if not isinstance(axis, Iterable): + axis = (axis,) + + axis = [ax + a.ndim if ax < 0 else ax for ax in axis] + + a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)}) + + result = a.map_blocks( + np.median, + axis=axis, + keepdims=keepdims, + drop_axis=axis if not keepdims else None, + chunks=[1 if ax in axis else c for ax, c in enumerate(a.chunks)] + if keepdims + else None, + ) + + return result + + +if LooseVersion(dask_version) >= LooseVersion("2.8.2"): + nanmedian = da.nanmedian +else: + + def nanmedian(a, axis=None, keepdims=False): + """ + This works by automatically chunking the reduced axes to a single chunk + and then calling ``numpy.nanmedian`` function across the remaining dimensions + """ + + if axis is None: + raise NotImplementedError( + "The da.nanmedian function only works along an axis. " + "The full algorithm is difficult to do in parallel" + ) + + if not isinstance(axis, Iterable): + axis = (axis,) + + axis = [ax + a.ndim if ax < 0 else ax for ax in axis] + + a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)}) + + result = a.map_blocks( + np.nanmedian, + axis=axis, + keepdims=keepdims, + drop_axis=axis if not keepdims else None, + chunks=[1 if ax in axis else c for ax, c in enumerate(a.chunks)] + if keepdims + else None, + ) + + return result diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index cf616acb485..98b371ab7c3 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -11,7 +11,7 @@ import numpy as np import pandas as pd -from . import dask_array_ops, dtypes, npcompat, nputils +from . 
import dask_array_ops, dask_array_compat, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast from .pycompat import dask_array_type @@ -284,7 +284,7 @@ def _ignore_warnings_if(condition): yield -def _create_nan_agg_method(name, coerce_strings=False): +def _create_nan_agg_method(name, dask_module=dask_array, coerce_strings=False): from . import nanops def f(values, axis=None, skipna=None, **kwargs): @@ -301,7 +301,7 @@ def f(values, axis=None, skipna=None, **kwargs): nanname = "nan" + name func = getattr(nanops, nanname) else: - func = _dask_or_eager_func(name) + func = _dask_or_eager_func(name, dask_module=dask_module) try: return func(values, axis=axis, **kwargs) @@ -337,7 +337,7 @@ def f(values, axis=None, skipna=None, **kwargs): std.numeric_only = True var = _create_nan_agg_method("var") var.numeric_only = True -median = _create_nan_agg_method("median") +median = _create_nan_agg_method("median", dask_module=dask_array_compat) median.numeric_only = True prod = _create_nan_agg_method("prod") prod.numeric_only = True diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index f70e96217e8..ad8276ba6af 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -6,6 +6,7 @@ try: import dask.array as dask_array + from . 
import dask_array_compat except ImportError: dask_array = None @@ -141,7 +142,9 @@ def nanmean(a, axis=None, dtype=None, out=None): def nanmedian(a, axis=None, out=None): - return _dask_or_eager_func("nanmedian", eager_module=nputils)(a, axis=axis) + return _dask_or_eager_func( + "nanmedian", dask_module=dask_array_compat, eager_module=nputils + )(a, axis=axis) def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 6122e987154..6054142b994 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -216,7 +216,7 @@ def test_reduce(self): self.assertLazyAndAllClose(u.argmin(dim="x"), actual) self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) - with raises_regex(NotImplementedError, "dask"): + with raises_regex(NotImplementedError, "only works along an axis"): v.median() with raise_if_dask_computes(): v.reduce(duck_array_ops.mean) From 51fbbf6ebc0087fc0630deca4e3782913198ef75 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 7 Dec 2019 13:35:54 -0700 Subject: [PATCH 2/6] Fix tests. 
--- xarray/core/dask_array_compat.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index a21217958e5..de55de89f0c 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -1,9 +1,14 @@ from distutils.version import LooseVersion from typing import Iterable -import dask.array as da import numpy as np -from dask import __version__ as dask_version + +try: + import dask.array as da + from dask import __version__ as dask_version +except ImportError: + dask_version = "0.0.0" + da = None if LooseVersion(dask_version) >= LooseVersion("2.0.0"): meta_from_array = da.utils.meta_from_array @@ -129,7 +134,7 @@ def median(a, axis=None, keepdims=False): return result -if LooseVersion(dask_version) >= LooseVersion("2.8.2"): +if LooseVersion(dask_version) > LooseVersion("2.9.0"): nanmedian = da.nanmedian else: From 54bea40afe4839507153a2ce7e176889ef395e46 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 7 Dec 2019 13:49:41 -0700 Subject: [PATCH 3/6] fix import --- xarray/core/nanops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index ad8276ba6af..cf47acee94c 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -9,6 +9,7 @@ from . import dask_array_compat except ImportError: dask_array = None + dask_array_compat = None # type: ignore def _replace_nan(a, val): From 285258d7d75aa4cc8a0c02c9d918564d2dd52503 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 17 Dec 2019 09:07:24 -0700 Subject: [PATCH 4/6] Make sure that we don't rechunk the entire variable to one chunk by reducing over all dimensions. Dask raises an error when axis=None but not when axis=range(a.ndim). 
--- xarray/core/nanops.py | 6 ++++++ xarray/tests/test_dask.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index cf47acee94c..9f3ad49a802 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -143,6 +143,12 @@ def nanmean(a, axis=None, dtype=None, out=None): def nanmedian(a, axis=None, out=None): + # The dask algorithm works by rechunking to one chunk along axis + # Make sure we trigger the dask error when passing all dimensions + # so that we don't rechunk the entire array to one chunk and + # possibly blow memory + if axis is not None and len(axis) == a.ndim: + axis = None return _dask_or_eager_func( "nanmedian", dask_module=dask_array_compat, eager_module=nputils )(a, axis=axis) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 6054142b994..d0e2654eed3 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -218,6 +218,8 @@ def test_reduce(self): self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) with raises_regex(NotImplementedError, "only works along an axis"): v.median() + with raises_regex(NotImplementedError, "only works along an axis"): + v.median(v.dims) with raise_if_dask_computes(): v.reduce(duck_array_ops.mean) From 905a90149487be77252314807a3c9347889b10b2 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 18 Dec 2019 08:34:31 -0700 Subject: [PATCH 5/6] fix tests. 
--- xarray/core/nanops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 9f3ad49a802..f9989c2c8c9 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -147,7 +147,7 @@ def nanmedian(a, axis=None, out=None): # Make sure we trigger the dask error when passing all dimensions # so that we don't rechunk the entire array to one chunk and # possibly blow memory - if axis is not None and len(axis) == a.ndim: + if axis is not None and len(np.atleast_1d(axis)) == a.ndim: axis = None return _dask_or_eager_func( "nanmedian", dask_module=dask_array_compat, eager_module=nputils From 5f22e5e50b021086a2cdad2b515daf8ef556a1e6 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 27 Dec 2019 04:41:46 +0000 Subject: [PATCH 6/6] Update whats-new.rst --- doc/whats-new.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f772d95751a..00d1c50780e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -27,6 +27,7 @@ New Features ~~~~~~~~~~~~ - Implement :py:func:`median` and :py:func:`nanmedian` for dask arrays. This works by rechunking to a single chunk along all reduction axes. (:issue:`2999`). + By `Deepak Cherian <https://github.com/dcherian>`_. - :py:func:`xarray.concat` now preserves attributes from the first Variable. (:issue:`2575`, :issue:`2060`, :issue:`1614`) By `Deepak Cherian <https://github.com/dcherian>`_.