Skip to content
5 changes: 4 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ v0.18.3 (unreleased)

New Features
~~~~~~~~~~~~

- Allow assigning values to a subset of a dataset using positional or label-based
indexing (:issue:`3015`, :pull:`5362`). By `Matthias Göbel <https://github.com/matzegoebel>`_.
- Attempting to reduce a weighted object over missing dimensions now raises an error (:pull:`5362`).
By `Mattia Almansi <https://github.com/malmans2>`_.
- :py:func:`xarray.cov` and :py:func:`xarray.corr` now lazily check for missing
values if inputs are dask arrays (:issue:`4804`, :pull:`5284`).
By `Andrew Williams <https://github.com/AndrewWilliams3142>`_.


Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
27 changes: 18 additions & 9 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,24 +1352,33 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None):

# 2. Ignore the nans
valid_values = da_a.notnull() & da_b.notnull()
valid_count = valid_values.sum(dim) - ddof

if not valid_values.all():
da_a = da_a.where(valid_values)
da_b = da_b.where(valid_values)
def _get_valid_values(da, other):
"""
Function to lazily mask da_a and da_b
following a similar approach to
https://github.com/pydata/xarray/pull/4559
"""
missing_vals = np.logical_or(da.isnull(), other.isnull())
if missing_vals.any():
da = da.where(~missing_vals)
return da
else:
return da

valid_count = valid_values.sum(dim) - ddof
da_a = da_a.map_blocks(_get_valid_values, args=[da_b])
da_b = da_b.map_blocks(_get_valid_values, args=[da_a])

# 3. Detrend along the given dim
demeaned_da_a = da_a - da_a.mean(dim=dim)
demeaned_da_b = da_b - da_b.mean(dim=dim)
# https://github.com/pydata/xarray/issues/4804#issuecomment-760114285
demeaned_da_ab = (da_a * da_b) - (da_a.mean(dim=dim) * da_b.mean(dim=dim))

# 4. Compute covariance along the given dim
# N.B. `skipna=False` is required or there is a bug when computing
# auto-covariance. E.g. Try xr.cov(da,da) for
# da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
cov = (demeaned_da_a * demeaned_da_b).sum(dim=dim, skipna=True, min_count=1) / (
valid_count
)
cov = demeaned_da_ab.sum(dim=dim, skipna=True, min_count=1) / (valid_count)

if method == "cov":
return cov
Expand Down
31 changes: 30 additions & 1 deletion xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
unified_dim_sizes,
)

from . import has_dask, requires_dask
from . import has_dask, raise_if_dask_computes, requires_dask

dask = pytest.importorskip("dask")

Expand Down Expand Up @@ -1386,6 +1386,7 @@ def arrays_w_tuples():
da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(),
xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]),
xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]),
xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]),
]

array_tuples = [
Expand All @@ -1394,12 +1395,40 @@ def arrays_w_tuples():
(arrays[1], arrays[1]),
(arrays[2], arrays[2]),
(arrays[2], arrays[3]),
(arrays[2], arrays[4]),
(arrays[4], arrays[2]),
(arrays[3], arrays[3]),
(arrays[4], arrays[4]),
]

return arrays, array_tuples


@pytest.mark.parametrize("ddof", [0, 1])
@pytest.mark.parametrize(
    "da_a, da_b",
    # tuples 3..8 from arrays_w_tuples(): the nan-containing pairings
    [arrays_w_tuples()[1][idx] for idx in range(3, 9)],
)
@pytest.mark.parametrize("dim", [None, "x", "time"])
def test_lazy_corrcov(da_a, da_b, dim, ddof):
    # GH 5284: cov/corr on chunked inputs must stay lazy — no dask
    # computation may be triggered while building the result graph.
    from dask import is_dask_collection

    with raise_if_dask_computes():
        cov = xr.cov(da_a.chunk(), da_b.chunk(), dim=dim, ddof=ddof)
        assert is_dask_collection(cov)

        corr = xr.corr(da_a.chunk(), da_b.chunk(), dim=dim)
        assert is_dask_collection(corr)


@pytest.mark.parametrize("ddof", [0, 1])
@pytest.mark.parametrize(
"da_a, da_b",
Expand Down