Dask-friendly nan check in xr.corr() and xr.cov() #5284

Merged on May 27, 2021 (10 commits)
18 changes: 17 additions & 1 deletion xarray/core/computation.py
@@ -1327,7 +1327,23 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None):
     # 2. Ignore the nans
     valid_values = da_a.notnull() & da_b.notnull()

-    if not valid_values.all():
+    def _nan_check(d):
+        if d.all():
+            return True
+        else:
+            return False
+
+    if is_duck_dask_array(valid_values.data):
+        # assign to copy - else the check is not triggered
+        _are_there_nans = valid_values.copy(
+            data=valid_values.data.map_blocks(_nan_check, dtype=valid_values.dtype),
+            deep=False,
+        )
+
+    else:
+        _are_there_nans = _nan_check(valid_values.data)
+
+    if not _are_there_nans:
Contributor

@dcherian dcherian May 10, 2021

Thanks @AndrewWilliams3142.

Unfortunately, I don't think you can avoid this for dask arrays. Fundamentally, bool(dask_array) will call compute, so the line "if not _are_there_nans" will call compute.

I think the only solution here is to special-case this optimization for non-dask DataArrays and always call .where(valid_values) for dask arrays, but maybe someone else has an idea.
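The point about bool() can be illustrated without dask at all. The sketch below is a toy, pure-Python stand-in for a lazy scalar (the LazyScalar class and its compute_calls counter are hypothetical, not dask's implementation). It only shows why an if statement cannot stay lazy: the branch needs a concrete True or False, so truthiness has to run the deferred work.

```python
class LazyScalar:
    """Toy stand-in for a lazy 0-d array; NOT dask's implementation."""

    def __init__(self, func):
        self._func = func
        self.compute_calls = 0

    def compute(self):
        # Run the deferred work and record that it happened.
        self.compute_calls += 1
        return self._func()

    def __bool__(self):
        # An `if` statement needs a concrete True/False, so truthiness
        # has to evaluate the deferred work.
        return bool(self.compute())


values = [1.0, float("nan"), 3.0]

# Deferred "are all values valid?" check (v == v is False only for NaN).
all_valid = LazyScalar(lambda: all(v == v for v in values))

# Building the lazy object computes nothing.
assert all_valid.compute_calls == 0

if not all_valid:  # branching calls __bool__, which calls compute()
    masked = [v for v in values if v == v]
else:
    masked = values
```

After the if statement, compute_calls is 1, even though compute() was never called explicitly. Wrapping the check in another lazy layer does not change this, because the branch itself is what forces evaluation.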

Contributor

OK I guess you could stick the

valid_values = ...
if valid_values.any():
    ...

in a helper function and map_blocks that.
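A minimal sketch of that idea, using plain Python lists as stand-ins for chunks. The names mask_block and map_blocks below are hypothetical, and this toy map_blocks runs eagerly, whereas a real block-wise map on dask arrays would defer the work. What it shows is that the check happens inside the helper, where each block is ordinary in-memory data, so the outer code never has to branch on a lazy value.

```python
import math


def mask_block(block_a, block_b):
    """Hypothetical per-block helper: mask block_a where either input is NaN."""
    missing = [math.isnan(a) or math.isnan(b) for a, b in zip(block_a, block_b)]
    if any(missing):  # eager check: inside a block the data is in memory
        return [math.nan if m else a for a, m in zip(block_a, missing)]
    return block_a  # nothing missing in this block: skip the masking work


def map_blocks(func, blocks_a, blocks_b):
    """Toy stand-in for a block-wise map: apply func to each pair of chunks."""
    return [func(a, b) for a, b in zip(blocks_a, blocks_b)]


blocks_a = [[1.0, 2.0], [3.0, 4.0]]
blocks_b = [[1.0, math.nan], [5.0, 6.0]]

# The first block pair contains a NaN and is masked; the second is
# returned unchanged.
result = map_blocks(mask_block, blocks_a, blocks_b)
```

With this layout the optimization is decided per block, so blocks without missing values skip the masking step while the overall computation can remain deferred.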

Contributor Author

Hmm, ok. Silly question but does it matter that _are_there_nans is a scalar reduction of valid_values? So valid_values is still left uncomputed. Would this help the bottleneck at all?

Contributor

No, because the program needs to know the value of valid_values.any() to decide which branch of the if-else statement to execute, so it gets computed.

         da_a = da_a.where(valid_values)
         da_b = da_b.where(valid_values)
