-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
cov() and corr() - finalization #3550
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ec58dfa
ec1dd72
184e4ab
ef1e25c
fe3fe06
d57eb7c
13a3925
e9a8869
617fd43
29b3e43
27980c8
5e4f002
28b1229
98a5f82
ea8f6eb
a4d0446
5a4b97c
1a6d939
93dfeeb
9a04ae3
363e238
48e4ad4
ac46a43
bc33c99
60831ff
3945ba1
bfe5b2b
24e9484
8f2a118
e09a607
2055962
a1063b5
b694b5f
b54487e
5e7b32d
bc0b100
4ac0320
3a7120c
ac00070
52ae54a
4f09263
4ab4af4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,8 @@ Top-level functions | |
zeros_like | ||
ones_like | ||
dot | ||
cov | ||
corr | ||
map_blocks | ||
|
||
Dataset | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,9 +22,9 @@ | |
) | ||
|
||
import numpy as np | ||
|
||
import xarray as xr | ||
from . import duck_array_ops, utils | ||
from .alignment import deep_align | ||
from .alignment import broadcast, deep_align | ||
from .merge import merge_coordinates_without_align | ||
from .pycompat import dask_array_type | ||
from .utils import is_dict_like | ||
|
@@ -1047,6 +1047,175 @@ def earth_mover_distance(first_samples, | |
return apply_array_ufunc(func, *args, dask=dask) | ||
|
||
|
||
def cov(da_a, da_b, dim=None, ddof=1): | ||
"""Compute covariance between two DataArray objects along a shared dimension. | ||
|
||
Parameters | ||
---------- | ||
da_a: DataArray (or Variable) object | ||
Array to compute. | ||
da_b: DataArray (or Variable) object | ||
Array to compute. | ||
dim : str, optional | ||
r-beer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
The dimension along which the covariance will be computed | ||
|
||
Returns | ||
------- | ||
covariance: DataArray | ||
|
||
See also | ||
-------- | ||
pandas.Series.cov: corresponding pandas function | ||
xr.corr: respective function to calculate correlation | ||
|
||
Examples | ||
-------- | ||
|
||
>>> da_a = DataArray(np.random.random((3, 5)), | ||
... dims=("space", "time"), | ||
... coords=[('space', ['IA', 'IL', 'IN']), | ||
... ('time', pd.date_range("2000-01-01", freq="1D", periods=5))]) | ||
>>> da_a | ||
<xarray.DataArray (space: 3, time: 5)> | ||
array([[0.04356841, 0.11479286, 0.70359101, 0.59072561, 0.16601438], | ||
[0.81552383, 0.72304926, 0.77644406, 0.05788198, 0.74065536], | ||
[0.96252519, 0.36877741, 0.22248412, 0.55185954, 0.23547536]]) | ||
Coordinates: | ||
* space (space) <U2 'IA' 'IL' 'IN' | ||
* time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05 | ||
|
||
>>> da_b = DataArray(np.random.random((3, 5)), | ||
... dims=("space", "time"), | ||
... coords=[('space', ['IA', 'IL', 'IN']), | ||
... ('time', pd.date_range("2000-01-01", freq="1D", periods=5))]) | ||
>>> da_b | ||
<xarray.DataArray (space: 3, time: 5)> | ||
array([[0.41505599, 0.43002193, 0.45250454, 0.57701084, 0.5327754 ], | ||
[0.0998048 , 0.67225522, 0.4234324 , 0.13514615, 0.4399088 ], | ||
[0.24675048, 0.58555283, 0.1942955 , 0.86128908, 0.05068975]]) | ||
Coordinates: | ||
* space (space) <U2 'IA' 'IL' 'IN' | ||
* time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05 | ||
>>> xr.cov(da_a, da_b) | ||
<xarray.DataArray ()> | ||
array(0.03823054) | ||
>>> xr.cov(da_a, da_b, dim='time') | ||
<xarray.DataArray (space: 3)> | ||
array([0.00207952, 0.01024296, 0.08214707]) | ||
Coordinates: | ||
* space (space) <U2 'IA' 'IL' 'IN' | ||
""" | ||
|
||
# 1. Broadcast the two arrays | ||
da_a, da_b = broadcast(da_a, da_b) | ||
|
||
# 2. Ignore the nans | ||
valid_values = da_a.notnull() & da_b.notnull() | ||
# TODO: avoid drop | ||
da_a = da_a.where(valid_values, drop=True) | ||
da_b = da_b.where(valid_values, drop=True) | ||
r-beer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
valid_count = valid_values.sum(dim) - ddof | ||
|
||
# if dim is not None: | ||
# valid_count = da_a[dim].size - ddof | ||
# else: | ||
# valid_count = da_a.size | ||
|
||
# 3. Compute mean and standard deviation along the given dim | ||
demeaned_da_a = da_a - da_a.mean(dim=dim) | ||
demeaned_da_b = da_b - da_b.mean(dim=dim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not demean in place, and avoid additional copies of the full arrays? Aren't |
||
|
||
# 4. Compute covariance along the given dim | ||
cov = (demeaned_da_a * demeaned_da_b).sum(dim=dim) / (valid_count) | ||
|
||
r-beer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return xr.DataArray(cov) | ||
|
||
|
||
def corr(da_a, da_b, dim=None, ddof=0): | ||
"""Compute the Pearson correlation coefficient between two DataArray objects along a shared dimension. | ||
|
||
Parameters | ||
---------- | ||
da_a: DataArray (or Variable) object | ||
Array to compute. | ||
da_b: DataArray (or Variable) object | ||
Array to compute. | ||
dim: str, optional | ||
The dimension along which the correlation will be computed | ||
|
||
Returns | ||
------- | ||
correlation: DataArray | ||
|
||
See also | ||
-------- | ||
pandas.Series.corr: corresponding pandas function | ||
r-beer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
xr.cov: underlying covariance function | ||
|
||
Examples | ||
-------- | ||
|
||
>>> da_a = DataArray(np.random.random((3, 5)), | ||
... dims=("space", "time"), | ||
... coords=[('space', ['IA', 'IL', 'IN']), | ||
... ('time', pd.date_range("2000-01-01", freq="1D", periods=5))]) | ||
>>> da_a | ||
<xarray.DataArray (space: 3, time: 5)> | ||
array([[0.04356841, 0.11479286, 0.70359101, 0.59072561, 0.16601438], | ||
[0.81552383, 0.72304926, 0.77644406, 0.05788198, 0.74065536], | ||
[0.96252519, 0.36877741, 0.22248412, 0.55185954, 0.23547536]]) | ||
Coordinates: | ||
* space (space) <U2 'IA' 'IL' 'IN' | ||
* time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05 | ||
|
||
>>> da_b = DataArray(np.random.random((3, 5)), | ||
... dims=("space", "time"), | ||
... coords=[('space', ['IA', 'IL', 'IN']), | ||
... ('time', pd.date_range("2000-01-01", freq="1D", periods=5))]) | ||
>>> da_b | ||
<xarray.DataArray (space: 3, time: 5)> | ||
array([[0.41505599, 0.43002193, 0.45250454, 0.57701084, 0.5327754 ], | ||
[0.0998048 , 0.67225522, 0.4234324 , 0.13514615, 0.4399088 ], | ||
[0.24675048, 0.58555283, 0.1942955 , 0.86128908, 0.05068975]]) | ||
Coordinates: | ||
* space (space) <U2 'IA' 'IL' 'IN' | ||
* time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05 | ||
>>> xr.corr(da_a, da_b) | ||
<xarray.DataArray ()> | ||
array(0.67407116) | ||
>>> xr.corr(da_a, da_b, dim='time') | ||
<xarray.DataArray (space: 3)> | ||
array([0.23150267, 0.24900968, 0.9061562 ]) | ||
Coordinates: | ||
* space (space) <U2 'IA' 'IL' 'IN' | ||
""" | ||
from .dataarray import DataArray | ||
|
||
if any(not isinstance(arr, (Variable, DataArray)) for arr in [da_a, da_b]): | ||
raise TypeError( | ||
"Only xr.DataArray and xr.Variable are supported." | ||
"Given {}.".format([type(arr) for arr in [da_a, da_b]]) | ||
) | ||
|
||
# 1. Broadcast the two arrays | ||
da_a, da_b = broadcast(da_a, da_b) | ||
|
||
# 2. Ignore the nans | ||
valid_values = da_a.notnull() & da_b.notnull() | ||
da_a = da_a.where(valid_values, drop=True) | ||
da_b = da_b.where( | ||
valid_values, drop=True | ||
) # TODO: avoid drop as explained in https://github.com/pydata/xarray/pull/2652#discussion_r245492002 | ||
|
||
# 3. Compute correlation based on standard deviations and cov() | ||
da_a_std = da_a.std(dim=dim) | ||
da_b_std = da_b.std(dim=dim) | ||
|
||
corr = cov(da_a, da_b, dim=dim, ddof=ddof) / (da_a_std * da_b_std) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is making a zillion copies of the full arrays. Please try to avoid unnecessary copies. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we're more likely to end up with a working, performant feature by making it work first, and then making it performant... |
||
|
||
return xr.DataArray(corr) | ||
|
||
|
||
def dot(*arrays, dims=None, **kwargs): | ||
"""Generalized dot product for xarray objects. Like np.einsum, but | ||
provides a simpler interface based on array dimensions. | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -5,7 +5,7 @@ | |||||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||||||
import pandas as pd | ||||||||||||||||||||||||||||||||
import pytest | ||||||||||||||||||||||||||||||||
from numpy.testing import assert_array_equal | ||||||||||||||||||||||||||||||||
from numpy.testing import assert_array_equal, assert_allclose | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
import xarray as xr | ||||||||||||||||||||||||||||||||
from xarray.core.computation import ( | ||||||||||||||||||||||||||||||||
|
@@ -789,6 +789,68 @@ def func(x): | |||||||||||||||||||||||||||||||
assert_identical(expected, actual) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
da = xr.DataArray(np.random.random((3, 21, 4)), | ||||||||||||||||||||||||||||||||
coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, | ||||||||||||||||||||||||||||||||
dims=("a", "time", "x"),) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
arrays = [ | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue with putting these at the module level (and one of the benefits of fixtures) is that each time this file is imported—even when no relevant test is being run—this will execute, and so it'll slow down rapid test collection. Could we put this in a fixture? I can help show you how if it's not clear. Thanks @r-beer There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I planned to make a fixture out of it. (How) Can I then easily use this fixture within parametrize? The following code does not work, as the function is not iterable: @pytest.fixture()
def array_tuples():
da = xr.DataArray(np.random.random((3, 21, 4)),
coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
dims=("a", "time", "x"),)
arrays = [
da.isel(time=range(0, 18)),
da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(dim="time"),
xr.DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time"),
xr.DataArray([1, 1, np.nan, 2, np.nan, 3, 5, 4, 6, np.nan, 7], dims="time"),
xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]),
xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]),
]
array_tuples = [
(arrays[0], arrays[0]),
(arrays[0], arrays[1]),
(arrays[1], arrays[1]),
(arrays[2], arrays[2]),
(arrays[2], arrays[3]),
(arrays[3], arrays[3]),
(arrays[4], arrays[4]),
(arrays[4], arrays[5]),
(arrays[5], arrays[5]),
]
return array_tuples
@pytest.mark.parametrize("da_a, da_b", array_tuples)
@pytest.mark.parametrize("dim", [None, "time", "x"])
def test_cov(da_a, da_b, dim):
def pandas_cov(ts1, ts2): Errors in: E TypeError: 'function' object is not iterable There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you want to use 5P2 combinations is fine for performance, particularly if the inputs are differentiated and so each test has marginal value. It does limit you to tests with a known correct implementation, though (you're not going to want to write out the results for all of those!) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Does that mean that all the permutations are done within the fixture? Is it not possible to put the fixture into the parametrize at all?
Haha, yeah, don't want to write all those by hand. 🙂 Thanks a lot for all the explanations! |
||||||||||||||||||||||||||||||||
da.isel(time=range(0, 18)), | ||||||||||||||||||||||||||||||||
da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(dim="time"), | ||||||||||||||||||||||||||||||||
xr.DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time"), | ||||||||||||||||||||||||||||||||
xr.DataArray([1, 1, np.nan, 2, np.nan, 3, 5, 4, 6, np.nan, 7], dims="time"), | ||||||||||||||||||||||||||||||||
xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]), | ||||||||||||||||||||||||||||||||
xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]), | ||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
array_tuples = [ | ||||||||||||||||||||||||||||||||
(arrays[0], arrays[0]), | ||||||||||||||||||||||||||||||||
(arrays[0], arrays[1]), | ||||||||||||||||||||||||||||||||
(arrays[1], arrays[1]), | ||||||||||||||||||||||||||||||||
(arrays[2], arrays[2]), | ||||||||||||||||||||||||||||||||
(arrays[2], arrays[3]), | ||||||||||||||||||||||||||||||||
(arrays[3], arrays[3]), | ||||||||||||||||||||||||||||||||
(arrays[4], arrays[4]), | ||||||||||||||||||||||||||||||||
(arrays[4], arrays[5]), | ||||||||||||||||||||||||||||||||
(arrays[5], arrays[5]), | ||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||
Comment on lines
+805
to
+815
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks like a good use case for
Suggested change
which means that you don't have to update Edit: GitHub apparently misunderstands multi-line suggestions, so you'd have to edit by hand. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If you want all permutations of the two variables, just pass them as separate parameters with pytest There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. sorry, I mistook So never mind. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the feedback. I have a question. If so, instead of using @pytest.mark.parametrize("da_b", arrays)
@pytest.mark.parametrize("da_a", arrays)
@pytest.mark.parametrize("dim", [None, "time", "x"])
def test_cov(da_a, da_b, dim): This is what @max-sixty means, right? However, might using all combinations not blow up the number of test functions unnecessarily? On the other hand, not using |
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
@pytest.mark.parametrize("da_a, da_b", array_tuples) | ||||||||||||||||||||||||||||||||
@pytest.mark.parametrize("dim", [None, "time", "x"]) | ||||||||||||||||||||||||||||||||
def test_cov(da_a, da_b, dim): | ||||||||||||||||||||||||||||||||
def pandas_cov(ts1, ts2): | ||||||||||||||||||||||||||||||||
"""Ensure the ts are aligned and missing values ignored""" | ||||||||||||||||||||||||||||||||
ts1, ts2 = xr.align(ts1, ts2) | ||||||||||||||||||||||||||||||||
valid_values = ts1.notnull() & ts2.notnull() | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
ts1 = ts1.where(valid_values, drop=True) | ||||||||||||||||||||||||||||||||
ts2 = ts2.where(valid_values, drop=True) | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would have thought that we should be testing that we get a similar result from pandas' |
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
return ts1.to_series().cov(ts2.to_series()) | ||||||||||||||||||||||||||||||||
r-beer marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
expected = pandas_cov(da_a, da_b) | ||||||||||||||||||||||||||||||||
actual = xr.cov(da_a, da_b, dim) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
assert_allclose(actual, expected) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
@pytest.mark.parametrize("da_a, da_b", array_tuples) | ||||||||||||||||||||||||||||||||
@pytest.mark.parametrize("dim", [None, "time", "x"]) | ||||||||||||||||||||||||||||||||
def test_corr(da_a, da_b, dim): | ||||||||||||||||||||||||||||||||
def pandas_corr(ts1, ts2): | ||||||||||||||||||||||||||||||||
"""Ensure the ts are aligned and missing values ignored""" | ||||||||||||||||||||||||||||||||
ts1, ts2 = xr.align(ts1, ts2) | ||||||||||||||||||||||||||||||||
valid_values = ts1.notnull() & ts2.notnull() | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
ts1 = ts1.where(valid_values, drop=True) | ||||||||||||||||||||||||||||||||
ts2 = ts2.where(valid_values, drop=True) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
return ts1.to_series().corr(ts2.to_series()) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
expected = pandas_corr(da_a, da_b) | ||||||||||||||||||||||||||||||||
actual = xr.corr(da_a, da_b, dim) | ||||||||||||||||||||||||||||||||
assert_allclose(actual, expected) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def pandas_median(x): | ||||||||||||||||||||||||||||||||
return pd.Series(x).median() | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
skipna=None
?