- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 1.2k
cov() and corr() - finalization #3550
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ec58dfa
              ec1dd72
              184e4ab
              ef1e25c
              fe3fe06
              d57eb7c
              13a3925
              e9a8869
              617fd43
              29b3e43
              27980c8
              5e4f002
              28b1229
              98a5f82
              ea8f6eb
              a4d0446
              5a4b97c
              1a6d939
              93dfeeb
              9a04ae3
              363e238
              48e4ad4
              ac46a43
              bc33c99
              60831ff
              3945ba1
              bfe5b2b
              24e9484
              8f2a118
              e09a607
              2055962
              a1063b5
              b694b5f
              b54487e
              5e7b32d
              bc0b100
              4ac0320
              3a7120c
              ac00070
              52ae54a
              4f09263
              4ab4af4
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -30,6 +30,8 @@ Top-level functions | |
| zeros_like | ||
| ones_like | ||
| dot | ||
| cov | ||
| corr | ||
| map_blocks | ||
|  | ||
| Dataset | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -22,9 +22,9 @@ | |
| ) | ||
|  | ||
| import numpy as np | ||
|  | ||
| import xarray as xr | ||
| from . import duck_array_ops, utils | ||
| from .alignment import deep_align | ||
| from .alignment import broadcast, deep_align | ||
| from .merge import merge_coordinates_without_align | ||
| from .pycompat import dask_array_type | ||
| from .utils import is_dict_like | ||
|  | @@ -1047,6 +1047,175 @@ def earth_mover_distance(first_samples, | |
| return apply_array_ufunc(func, *args, dask=dask) | ||
|  | ||
|  | ||
def cov(da_a, da_b, dim=None, ddof=1):
    """Compute covariance between two DataArray objects along a shared dimension.

    Positions where either input is null are ignored in both inputs.

    Parameters
    ----------
    da_a: DataArray (or Variable) object
        Array to compute.
    da_b: DataArray (or Variable) object
        Array to compute.
    dim : str, optional
        The dimension along which the covariance will be computed
    ddof : int, optional
        Delta degrees of freedom. The sum of products is divided by
        ``N - ddof``, where ``N`` is the number of positions at which both
        inputs are non-null. Defaults to 1.

    Returns
    -------
    covariance: DataArray

    See also
    --------
    pandas.Series.cov: corresponding pandas function
    xr.corr: respective function to calculate correlation

    Examples
    --------

    >>> da_a = DataArray(np.random.random((3, 5)),
    ...                  dims=("space", "time"),
    ...                  coords=[('space', ['IA', 'IL', 'IN']),
    ...                          ('time', pd.date_range("2000-01-01", freq="1D", periods=5))])
    >>> da_a
    <xarray.DataArray (space: 3, time: 5)>
    array([[0.04356841, 0.11479286, 0.70359101, 0.59072561, 0.16601438],
           [0.81552383, 0.72304926, 0.77644406, 0.05788198, 0.74065536],
           [0.96252519, 0.36877741, 0.22248412, 0.55185954, 0.23547536]])
    Coordinates:
      * space    (space) <U2 'IA' 'IL' 'IN'
      * time     (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05

    >>> da_b = DataArray(np.random.random((3, 5)),
    ...                  dims=("space", "time"),
    ...                  coords=[('space', ['IA', 'IL', 'IN']),
    ...                          ('time', pd.date_range("2000-01-01", freq="1D", periods=5))])
    >>> da_b
    <xarray.DataArray (space: 3, time: 5)>
    array([[0.41505599, 0.43002193, 0.45250454, 0.57701084, 0.5327754 ],
           [0.0998048 , 0.67225522, 0.4234324 , 0.13514615, 0.4399088 ],
           [0.24675048, 0.58555283, 0.1942955 , 0.86128908, 0.05068975]])
    Coordinates:
      * space    (space) <U2 'IA' 'IL' 'IN'
      * time     (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05
    >>> xr.cov(da_a, da_b)
    <xarray.DataArray ()>
    array(0.03823054)
    >>> xr.cov(da_a, da_b, dim='time')
    <xarray.DataArray (space: 3)>
    array([0.00207952, 0.01024296, 0.08214707])
    Coordinates:
      * space    (space) <U2 'IA' 'IL' 'IN'
    """

    # 1. Broadcast the two arrays
    da_a, da_b = broadcast(da_a, da_b)

    # 2. Ignore the nans.
    # Mask instead of dropping: `valid_count` is derived from the full-size
    # mask, so the data must keep the same shape and labels as the mask.
    valid_values = da_a.notnull() & da_b.notnull()
    da_a = da_a.where(valid_values)
    da_b = da_b.where(valid_values)
    valid_count = valid_values.sum(dim) - ddof

    # 3. Remove the mean along the given dim
    # NOTE(review): a reviewer asked whether the demeaning could be done in
    # place to avoid additional copies of the full arrays.
    demeaned_da_a = da_a - da_a.mean(dim=dim)
    demeaned_da_b = da_b - da_b.mean(dim=dim)

    # 4. Compute covariance along the given dim.
    # `min_count=1` makes a slice without any valid pair yield NaN instead of
    # a sum of 0 divided by a non-positive count.
    covariance = (demeaned_da_a * demeaned_da_b).sum(
        dim=dim, skipna=True, min_count=1
    ) / valid_count

    return xr.DataArray(covariance)
|  | ||
|  | ||
def corr(da_a, da_b, dim=None, ddof=0):
    """Compute the Pearson correlation coefficient between two DataArray objects along a shared dimension.

    Positions where either input is null are ignored in both inputs.

    Parameters
    ----------
    da_a: DataArray (or Variable) object
        Array to compute.
    da_b: DataArray (or Variable) object
        Array to compute.
    dim: str, optional
        The dimension along which the correlation will be computed
    ddof: int, optional
        Delta degrees of freedom, forwarded to both the covariance and the
        standard deviations so that numerator and denominator use the same
        normalisation. Defaults to 0.

    Returns
    -------
    correlation: DataArray

    Raises
    ------
    TypeError
        If either input is neither a DataArray nor a Variable.

    See also
    --------
    pandas.Series.corr: corresponding pandas function
    xr.cov: underlying covariance function

    Examples
    --------

    >>> da_a = DataArray(np.random.random((3, 5)),
    ...                  dims=("space", "time"),
    ...                  coords=[('space', ['IA', 'IL', 'IN']),
    ...                          ('time', pd.date_range("2000-01-01", freq="1D", periods=5))])
    >>> da_a
    <xarray.DataArray (space: 3, time: 5)>
    array([[0.04356841, 0.11479286, 0.70359101, 0.59072561, 0.16601438],
           [0.81552383, 0.72304926, 0.77644406, 0.05788198, 0.74065536],
           [0.96252519, 0.36877741, 0.22248412, 0.55185954, 0.23547536]])
    Coordinates:
      * space    (space) <U2 'IA' 'IL' 'IN'
      * time     (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05

    >>> da_b = DataArray(np.random.random((3, 5)),
    ...                  dims=("space", "time"),
    ...                  coords=[('space', ['IA', 'IL', 'IN']),
    ...                          ('time', pd.date_range("2000-01-01", freq="1D", periods=5))])
    >>> da_b
    <xarray.DataArray (space: 3, time: 5)>
    array([[0.41505599, 0.43002193, 0.45250454, 0.57701084, 0.5327754 ],
           [0.0998048 , 0.67225522, 0.4234324 , 0.13514615, 0.4399088 ],
           [0.24675048, 0.58555283, 0.1942955 , 0.86128908, 0.05068975]])
    Coordinates:
      * space    (space) <U2 'IA' 'IL' 'IN'
      * time     (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05
    >>> xr.corr(da_a, da_b)
    <xarray.DataArray ()>
    array(0.67407116)
    >>> xr.corr(da_a, da_b, dim='time')
    <xarray.DataArray (space: 3)>
    array([0.23150267, 0.24900968, 0.9061562 ])
    Coordinates:
      * space    (space) <U2 'IA' 'IL' 'IN'
    """
    from .dataarray import DataArray

    if any(not isinstance(arr, (Variable, DataArray)) for arr in [da_a, da_b]):
        raise TypeError(
            "Only xr.DataArray and xr.Variable are supported. "
            "Given {}.".format([type(arr) for arr in [da_a, da_b]])
        )

    # 1. Broadcast the two arrays
    da_a, da_b = broadcast(da_a, da_b)

    # 2. Ignore the nans
    # TODO: avoid drop as explained in
    # https://github.com/pydata/xarray/pull/2652#discussion_r245492002
    valid_values = da_a.notnull() & da_b.notnull()
    da_a = da_a.where(valid_values, drop=True)
    da_b = da_b.where(valid_values, drop=True)

    # 3. Compute correlation based on standard deviations and cov().
    # The standard deviations must use the same `ddof` as the covariance,
    # otherwise the ratio is off by a factor of N / (N - ddof).
    # NOTE(review): a reviewer pointed out that this makes many copies of the
    # full arrays; the reply was to make it work first and optimise later.
    da_a_std = da_a.std(dim=dim, ddof=ddof)
    da_b_std = da_b.std(dim=dim, ddof=ddof)

    correlation = cov(da_a, da_b, dim=dim, ddof=ddof) / (da_a_std * da_b_std)

    return xr.DataArray(correlation)
|  | ||
|  | ||
| def dot(*arrays, dims=None, **kwargs): | ||
| """Generalized dot product for xarray objects. Like np.einsum, but | ||
| provides a simpler interface based on array dimensions. | ||
|  | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|  | @@ -5,7 +5,7 @@ | |||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||
| import pandas as pd | ||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||
| from numpy.testing import assert_array_equal | ||||||||||||||||||||||||||||||||
| from numpy.testing import assert_array_equal, assert_allclose | ||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||
| import xarray as xr | ||||||||||||||||||||||||||||||||
| from xarray.core.computation import ( | ||||||||||||||||||||||||||||||||
|  | @@ -789,6 +789,68 @@ def func(x): | |||||||||||||||||||||||||||||||
| assert_identical(expected, actual) | ||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||
# Shared inputs for test_cov / test_corr.
# NOTE(review): reviewers pointed out that building these at module level
# runs every time this file is imported, which slows test collection, and
# suggested a fixture instead. A fixture function cannot be passed directly to
# `pytest.mark.parametrize` (TypeError: 'function' object is not iterable),
# so the data stays at module level here.
da = xr.DataArray(
    np.random.random((3, 21, 4)),
    coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
    dims=("a", "time", "x"),
)

arrays = [
    da.isel(time=range(0, 18)),
    da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(dim="time"),
    xr.DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time"),
    xr.DataArray([1, 1, np.nan, 2, np.nan, 3, 5, 4, 6, np.nan, 7], dims="time"),
    xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]),
    xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]),
]

# Pairs of entries of `arrays`, given as (index of da_a, index of da_b).
array_tuples = [
    (arrays[first], arrays[second])
    for first, second in [
        (0, 0),
        (0, 1),
        (1, 1),
        (2, 2),
        (2, 3),
        (3, 3),
        (4, 4),
        (4, 5),
        (5, 5),
    ]
]
| 
      Comment on lines
    
      +805
     to 
      +815
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks like a good use case for  
        Suggested change
       
 which means that you don't have to update  Edit: GitHub apparently misunderstands multi-line suggestions, you'd have to edit by hand. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you want all permutations of the two variables, just pass them as separate parameters with pytest There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I mistook  So never mind. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the feedback. I have a question: If so,  instead of using  @pytest.mark.parametrize("da_b", arrays)
@pytest.mark.parametrize("da_a", arrays)
@pytest.mark.parametrize("dim", [None, "time", "x"])
def test_cov(da_a, da_b, dim): This is what @max-sixty means, right? However, might using all combinations blow up the number of test functions unnecessarily? On the other hand not using  | ||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||
@pytest.mark.parametrize("da_a, da_b", array_tuples)
@pytest.mark.parametrize("dim", [None, "time", "x"])
def test_cov(da_a, da_b, dim):
    """Compare ``xr.cov`` with ``pandas.Series.cov`` on the same pair of arrays."""

    def pandas_cov(ts1, ts2):
        """Ensure the ts are aligned and missing values ignored"""
        aligned_1, aligned_2 = xr.align(ts1, ts2)
        both_valid = aligned_1.notnull() & aligned_2.notnull()

        series_1 = aligned_1.where(both_valid, drop=True).to_series()
        series_2 = aligned_2.where(both_valid, drop=True).to_series()

        # NOTE(review): this reference is one scalar over all remaining
        # values and does not depend on `dim` - confirm that this is the
        # intended expectation when `dim` is given. A reviewer raised a
        # related question about how the result is compared with pandas.
        return series_1.cov(series_2)

    expected = pandas_cov(da_a, da_b)
    actual = xr.cov(da_a, da_b, dim)

    assert_allclose(actual, expected)
|  | ||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||
@pytest.mark.parametrize("da_a, da_b", array_tuples)
@pytest.mark.parametrize("dim", [None, "time", "x"])
def test_corr(da_a, da_b, dim):
    """Compare ``xr.corr`` with ``pandas.Series.corr`` on the same pair of arrays."""

    def pandas_corr(ts1, ts2):
        """Ensure the ts are aligned and missing values ignored"""
        aligned_1, aligned_2 = xr.align(ts1, ts2)
        both_valid = aligned_1.notnull() & aligned_2.notnull()

        series_1 = aligned_1.where(both_valid, drop=True).to_series()
        series_2 = aligned_2.where(both_valid, drop=True).to_series()

        # NOTE(review): this reference is one scalar over all remaining
        # values and does not depend on `dim` - confirm that this is the
        # intended expectation when `dim` is given.
        return series_1.corr(series_2)

    expected = pandas_corr(da_a, da_b)
    actual = xr.corr(da_a, da_b, dim)
    assert_allclose(actual, expected)
|  | ||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||
def pandas_median(x):
    """Return the median of ``x`` as computed by ``pandas.Series.median``."""
    series = pd.Series(x)
    return series.median()
|  | ||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
skipna=None?