Skip to content

Commit 3b81a86

Browse files
Dask-friendly nan check in xr.corr() and xr.cov() (#5284)
* initial changes * Using map_blocks to lazily mask input arrays, following #4559 * Adding lazy corr cov test with `raise_if_dask_computes` * adding test for one da without nans * checking ordering of arrays doesn't matter Co-authored-by: Deepak Cherian <[email protected]> * adjust inputs to test * add test for array with no missing values * added whatsnew * fixing format issues Co-authored-by: Deepak Cherian <[email protected]>
1 parent a6a1e48 commit 3b81a86

File tree

3 files changed

+49
-6
lines changed

3 files changed

+49
-6
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@ v0.18.3 (unreleased)
2121

2222
New Features
2323
~~~~~~~~~~~~
24-
2524
- Allow assigning values to a subset of a dataset using positional or label-based
2625
indexing (:issue:`3015`, :pull:`5362`). By `Matthias Göbel <https://github.com/matzegoebel>`_.
2726
- Attempting to reduce a weighted object over missing dimensions now raises an error (:pull:`5362`).
2827
By `Mattia Almansi <https://github.com/malmans2>`_.
28+
- :py:func:`xarray.cov` and :py:func:`xarray.corr` now lazily check for missing
29+
values if inputs are dask arrays (:issue:`4804`, :pull:`5284`).
30+
By `Andrew Williams <https://github.com/AndrewWilliams3142>`_.
31+
2932

3033
Breaking changes
3134
~~~~~~~~~~~~~~~~

xarray/core/computation.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,12 +1352,23 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None):
13521352

13531353
# 2. Ignore the nans
13541354
valid_values = da_a.notnull() & da_b.notnull()
1355+
valid_count = valid_values.sum(dim) - ddof
13551356

1356-
if not valid_values.all():
1357-
da_a = da_a.where(valid_values)
1358-
da_b = da_b.where(valid_values)
1357+
def _get_valid_values(da, other):
1358+
"""
1359+
Function to lazily mask da_a and da_b
1360+
following a similar approach to
1361+
https://github.com/pydata/xarray/pull/4559
1362+
"""
1363+
missing_vals = np.logical_or(da.isnull(), other.isnull())
1364+
if missing_vals.any():
1365+
da = da.where(~missing_vals)
1366+
return da
1367+
else:
1368+
return da
13591369

1360-
valid_count = valid_values.sum(dim) - ddof
1370+
da_a = da_a.map_blocks(_get_valid_values, args=[da_b])
1371+
da_b = da_b.map_blocks(_get_valid_values, args=[da_a])
13611372

13621373
# 3. Detrend along the given dim
13631374
demeaned_da_a = da_a - da_a.mean(dim=dim)

xarray/tests/test_computation.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
unified_dim_sizes,
2323
)
2424

25-
from . import has_dask, requires_dask
25+
from . import has_dask, raise_if_dask_computes, requires_dask
2626

2727
dask = pytest.importorskip("dask")
2828

@@ -1386,6 +1386,7 @@ def arrays_w_tuples():
13861386
da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(),
13871387
xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]),
13881388
xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]),
1389+
xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]),
13891390
]
13901391

13911392
array_tuples = [
@@ -1394,12 +1395,40 @@ def arrays_w_tuples():
13941395
(arrays[1], arrays[1]),
13951396
(arrays[2], arrays[2]),
13961397
(arrays[2], arrays[3]),
1398+
(arrays[2], arrays[4]),
1399+
(arrays[4], arrays[2]),
13971400
(arrays[3], arrays[3]),
1401+
(arrays[4], arrays[4]),
13981402
]
13991403

14001404
return arrays, array_tuples
14011405

14021406

1407+
@pytest.mark.parametrize("ddof", [0, 1])
1408+
@pytest.mark.parametrize(
1409+
"da_a, da_b",
1410+
[
1411+
arrays_w_tuples()[1][3],
1412+
arrays_w_tuples()[1][4],
1413+
arrays_w_tuples()[1][5],
1414+
arrays_w_tuples()[1][6],
1415+
arrays_w_tuples()[1][7],
1416+
arrays_w_tuples()[1][8],
1417+
],
1418+
)
1419+
@pytest.mark.parametrize("dim", [None, "x", "time"])
1420+
def test_lazy_corrcov(da_a, da_b, dim, ddof):
1421+
# GH 5284
1422+
from dask import is_dask_collection
1423+
1424+
with raise_if_dask_computes():
1425+
cov = xr.cov(da_a.chunk(), da_b.chunk(), dim=dim, ddof=ddof)
1426+
assert is_dask_collection(cov)
1427+
1428+
corr = xr.corr(da_a.chunk(), da_b.chunk(), dim=dim)
1429+
assert is_dask_collection(corr)
1430+
1431+
14031432
@pytest.mark.parametrize("ddof", [0, 1])
14041433
@pytest.mark.parametrize(
14051434
"da_a, da_b",

0 commit comments

Comments
 (0)