Improve performance of xarray.corr() on big datasets #4804

Open
kathoef opened this issue Jan 13, 2021 · 9 comments · Fixed by #5284

@kathoef

kathoef commented Jan 13, 2021

Is your feature request related to a problem? Please describe.

I calculated correlation coefficients for datasets between 90 and 180 GB in size using xarray and Dask distributed, and experienced very low performance from the xarray.corr() function. The Dask dashboard suggested that the whole datasets are loaded from disk several times during the calculation, which, given the size of my datasets, became a major performance bottleneck for some of the calculations.

Describe the solution you'd like

The problem became so annoying that I implemented my own function to calculate the correlation coefficient (thanks @willirath!). It is considerably more performant, especially for the big datasets, because it only touches the full data once. I have uploaded a Jupyter notebook that shows the equivalence of xarray.corr() and my implementation (using an "unaligned data with nan values" example, which is what xarray.corr() covers). The notebook also contains an example based on Dask arrays, which demonstrates the performance problems stated above and shows that xarray.corr() is not fully lazy. (Which I assume is actually not very desirable?)

At the moment, I think a considerable improvement in big-data performance could be achieved by removing the if not valid_values.all() clause here, because that seems to be what makes a call of xarray.corr() not fully lazy and causes the first (of several?) full touches of the datasets. I haven't checked what's going on afterwards, but maybe that is already a useful starting point? 🤔
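
To illustrate the point about the guard clause, here is a minimal NumPy sketch (not the xarray source; the helper name mask_invalid_pairs is made up for this example). It applies the mask unconditionally instead of first testing valid.all(). On in-memory arrays both variants give the same values; on Dask-backed arrays, evaluating the boolean test is what forces a computation, so skipping it should keep the call lazy.

```python
import numpy as np

def mask_invalid_pairs(a, b):
    """Set positions to NaN wherever either input is NaN (hypothetical helper)."""
    valid = ~np.isnan(a) & ~np.isnan(b)
    # No `if not valid.all()` guard: masking with an all-True mask is a no-op,
    # so the result is identical and the mask never has to be evaluated eagerly.
    return np.where(valid, a, np.nan), np.where(valid, b, np.nan)

a = np.array([1.0, 2.0, np.nan, 4.0])
b = np.array([2.0, np.nan, 6.0, 8.0])
a_masked, b_masked = mask_invalid_pairs(a, b)

# Inputs without NaN pass through unchanged.
c = np.array([1.0, 2.0, 3.0])
c_masked, _ = mask_invalid_pairs(c, c)
```

After masking, both arrays are NaN at the same positions, so later reductions that skip NaN see only complete pairs.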

@mathause
Collaborator

Yes, the if not valid_values.all() check is not lazy. That's the same problem as #4541, so #4559 can serve as inspiration for how to tackle this. It would be good to test whether the check also makes this slower for numpy arrays; if so, it could be removed entirely. That would be counter-intuitive to me, but it seems to be faster for dask arrays...

Other improvements

  • I am not sure if /= avoids a copy but if so, that's also a possibility to make it faster.
  • We could add a short-cut for skipna=False (would require adding this option) or dtypes that cannot have NA values as follows:
if skipna:
    # 2. Ignore the nans
    valid_values = da_a.notnull() & da_b.notnull()

    if not valid_values.all():
        da_a = da_a.where(valid_values)
        da_b = da_b.where(valid_values)

    valid_count = valid_values.sum(dim) - ddof
else:
    # shortcut for skipna=False
    # da_a and da_b are aligned, so they have the same dims and shape
    axis = da_a.get_axis_num(dim)
    valid_count = np.take(da_a.shape, axis).prod() - ddof
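
As a self-contained check of the shortcut above, the following NumPy sketch compares the two ways of obtaining the count. The function name count_valid and the plain-array inputs are assumptions for illustration only; xarray.corr() has no such helper. With no NaN present, the shape-based count should match the mask-based one.

```python
import numpy as np

def count_valid(a, b, axis, ddof=1, skipna=True):
    """Number of value pairs along `axis`, minus ddof (hypothetical helper)."""
    if skipna:
        valid = ~np.isnan(a) & ~np.isnan(b)
        return valid.sum(axis=axis) - ddof
    # Shortcut: with no NaN handling, the count follows from the shape alone.
    return np.take(a.shape, axis).prod() - ddof

a = np.ones((3, 4))
b = np.ones((3, 4))
with_check = count_valid(a, b, axis=1, skipna=True)   # one count per row
shortcut = count_valid(a, b, axis=1, skipna=False)    # a single scalar
```

The shortcut returns a scalar instead of one count per row, and it never touches the data, only the shape.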

@aaronspring
Contributor

We implemented xr.corr as xs.pearson_r in https://xskillscore.readthedocs.io/en/stable/api/xskillscore.pearson_r.html#xskillscore.pearson_r
and it's ~30% faster than xr.corr; see #4768.

@aaronspring
Contributor

Your function from the notebook could also easily implement the builtin weighted function

@mathause
Collaborator

mathause commented Jan 13, 2021

@aaronspring I had a quick look at your version - do you have an idea why it is faster? Does yours also work for dask arrays?

  • In a, b = xr.broadcast(a, b, exclude=dim) why can you exclude dim?
  • I think you could also use a, b = xr.align(a, b, exclude=dim) (broadcast has join="outer", which fills it with NA values that then get ignored; align uses join="inner")
  • Does your version work if the weights contain NA?

@mathause
Collaborator

Another possibility is to replace

cov = (demeaned_da_a * demeaned_da_b).sum(dim=dim, skipna=True, min_count=1) / (

with xr.dot. However, to do so, you need to replace NA with 0 (and I am not sure if that's worth it). Also the min_count needs to be addressed (but that should not be too difficult).
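
A rough NumPy sketch of that idea, using np.einsum as a stand-in for xr.dot (the function name masked_dot is hypothetical, and this is not how xarray implements it): invalid pairs are replaced with 0 so they contribute nothing to the dot product, and the min_count=1 behaviour is emulated by returning NaN where no valid pair exists.

```python
import numpy as np

def masked_dot(a, b):
    """Sum of a*b over the last axis, skipping pairs that contain NaN."""
    valid = ~np.isnan(a) & ~np.isnan(b)
    # Replace NaN pairs with 0 so they add nothing to the dot product.
    a0 = np.where(valid, a, 0.0)
    b0 = np.where(valid, b, 0.0)
    total = np.einsum("...i,...i->...", a0, b0)
    # Emulate min_count=1: with no valid pair the result is NaN, not 0.
    return np.where(valid.any(axis=-1), total, np.nan)

a = np.array([[1.0, 2.0, np.nan], [np.nan, np.nan, np.nan]])
b = np.array([[4.0, np.nan, 6.0], [1.0, 2.0, 3.0]])
result = masked_dot(a, b)
```

The first row has one complete pair (1.0 and 4.0); the second row has none, so it comes back as NaN rather than 0.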

@aaronspring
Contributor

Thanks for the suggestion with xr.align.

My speculation is that xs.pearson_r is a bit faster because we first write the whole function in numpy and then pass it through xr.apply_ufunc. I therefore think it only works for xr but not dask.da.

@willirath
Contributor

I'd also add that https://github.com/pydata/xarray/blob/master/xarray/core/computation.py#L1320_L1330 which is essentially

((x - x.mean()) * (y - y.mean())).mean()

is inferior to

(x * y).mean() - x.mean() * y.mean()

because it leads to Dask holding all chunks of x in memory (see, e.g., dask/dask#6674 for details).
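
The two expressions are algebraically equivalent, which a small NumPy check can confirm. The sketch below uses synthetic data and plain in-memory arrays, so it only shows that the values agree, not the Dask memory behaviour. It also builds the correlation coefficient from the same kind of single-pass terms, as an assumption about how a one-pass implementation might look.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=1000)
y = 0.5 * x + rng.normal(size=1000)

# Covariance from demeaned arrays (needs the means before the main reduction).
cov_demeaned = ((x - x.mean()) * (y - y.mean())).mean()

# Covariance from plain means, each of which is a single reduction.
cov_single_pass = (x * y).mean() - x.mean() * y.mean()

# Pearson r built from the same kind of terms.
var_x = (x * x).mean() - x.mean() ** 2
var_y = (y * y).mean() - y.mean() ** 2
r_single_pass = cov_single_pass / np.sqrt(var_x * var_y)
```

One caveat: when the means are large compared with the spread of the data, the single-pass form subtracts two nearly equal numbers and can lose floating-point precision.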

@dcherian
Contributor

@kathoef we'd be happy to merge a PR with some of the suggestions proposed here.

@dcherian
Contributor

Reopening for the suggestions in #4804 (comment)

cc @AndrewWilliams3142 if you're looking for a small followup PR with potentially large impact :)


7 participants