Slow performance of DataArray.unstack() from checking variable.data #5902


Closed
TomAugspurger opened this issue Oct 27, 2021 · 4 comments

@TomAugspurger
Contributor

What happened:

Calling DataArray.unstack() spends time allocating an object-dtype NumPy array from values of the pandas MultiIndex.
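For context, coercing a pandas MultiIndex to an array builds a fresh object-dtype array of tuples, which is what makes this allocation expensive for large indexes. A small illustration (tiny sizes, for demonstration only):

```python
import pandas as pd

# Coercing a MultiIndex to an array builds an object-dtype array of
# tuples from scratch -- cheap here, but costly for millions of rows.
mi = pd.MultiIndex.from_product([range(3), range(3)], names=["y", "x"])
vals = mi.values
print(vals.dtype)  # object
print(vals[0])     # (0, 0)
```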

What you expected to happen:

Faster unstack.

Minimal Complete Verifiable Example:

import pandas as pd
import numpy as np
import xarray as xr

t = pd.date_range("2000", periods=2)
x = np.arange(1000)
y = np.arange(1000)
component = np.arange(4)

idx = pd.MultiIndex.from_product([t, y, x], names=["time", "y", "x"])

data = np.random.uniform(size=(len(idx), len(component)))
arr = xr.DataArray(
    data,
    coords={"pixel": xr.DataArray(idx, name="pixel", dims="pixel"),
            "component": xr.DataArray(component, name="component", dims="component")},
    dims=("pixel", "component")
)

%time _ = arr.unstack()
CPU times: user 6.33 s, sys: 295 ms, total: 6.62 s
Wall time: 6.62 s

Anything else we need to know?:

For this example, >99% of the time is spent on this line:

any(is_duck_dask_array(v.data) for v in self.variables.values())

specifically on the call to v.data for the pixel variable, which is backed by a pandas MultiIndex.

Just going by the comments, it does seem like accessing v.data is necessary to perform the check. I'm wondering if we could make is_duck_dask_array a bit smarter, to avoid unnecessarily allocating data?

Alternatively, if that's too difficult, perhaps we could add a flag to unstack to disable those checks and just take the "slow" path. In my actual use-case, the slow _unstack_full_reindex is necessary since I have large Dask Arrays. But even then, the unstack completes in less than 3s, while I was getting OOM killed on the v.data checks.

Environment:

Output of xr.show_versions()
INSTALLED VERSIONS
------------------
commit: None
python: 3.8.12 | packaged by conda-forge | (default, Sep 29 2021, 19:52:28) 
[GCC 9.4.0]
python-bits: 64
OS: Linux
OS-release: 5.4.0-1040-azure
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: C.UTF-8
LANG: C.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.12.1
libnetcdf: 4.8.1

xarray: 0.19.0
pandas: 1.3.3
numpy: 1.20.0
scipy: 1.7.1
netCDF4: 1.5.7
pydap: installed
h5netcdf: 0.11.0
h5py: 3.4.0
Nio: None
zarr: 2.10.1
cftime: 1.5.1
nc_time_axis: 1.3.1
PseudoNetCDF: None
rasterio: 1.2.9
cfgrib: 0.9.9.0
iris: None
bottleneck: 1.3.2
dask: 2021.08.1
distributed: 2021.08.1
matplotlib: 3.4.3
cartopy: 0.20.0
seaborn: 0.11.2
numbagg: None
pint: 0.17
setuptools: 58.0.4
pip: 20.3.4
conda: None
pytest: None
IPython: 7.28.0
sphinx: None

@TomAugspurger
Contributor Author

TomAugspurger commented Oct 27, 2021

Oh, hmm... I'm noticing now that IndexVariable (currently) eagerly loads its data into memory, so that check will always be false for the problematic IndexVariable.

So perhaps a slight adjustment to is_duck_dask_array to handle xarray.Variable?

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 550c3587..16637574 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -4159,14 +4159,14 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping):
                 # Dask arrays don't support assignment by index, which the fast unstack
                 # function requires.
                 # https://github.com/pydata/xarray/pull/4746#issuecomment-753282125
-                any(is_duck_dask_array(v.data) for v in self.variables.values())
+                any(is_duck_dask_array(v) for v in self.variables.values())
                 # Sparse doesn't currently support (though we could special-case
                 # it)
                 # https://github.com/pydata/sparse/issues/422
-                or any(
-                    isinstance(v.data, sparse_array_type)
-                    for v in self.variables.values()
-                )
+                # or any(
+                #     isinstance(v.data, sparse_array_type)
+                #     for v in self.variables.values()
+                # )
                 or sparse
                 # Until https://github.com/pydata/xarray/pull/4751 is resolved,
                 # we check explicitly whether it's a numpy array. Once that is
@@ -4177,9 +4177,9 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping):
                 # # or any(
                 # #     isinstance(v.data, pint_array_type) for v in self.variables.values()
                 # # )
-                or any(
-                    not isinstance(v.data, np.ndarray) for v in self.variables.values()
-                )
+                # or any(
+                #     not isinstance(v.data, np.ndarray) for v in self.variables.values()
+                # )
             ):
                 result = result._unstack_full_reindex(dim, fill_value, sparse)
             else:
diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py
index d1649235..e9669105 100644
--- a/xarray/core/pycompat.py
+++ b/xarray/core/pycompat.py
@@ -44,6 +44,12 @@ class DuckArrayModule:
 
 
 def is_duck_dask_array(x):
+    from xarray.core.variable import IndexVariable, Variable
+    if isinstance(x, IndexVariable):
+        return False
+    elif isinstance(x, Variable):
+        x = x.data
+
     if DuckArrayModule("dask").available:
         from dask.base import is_dask_collection

That's completely ignoring the accesses to v.data for the sparse and pint checks, which don't look quite as easy to solve.
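One possible direction for those remaining checks (untested, and assuming xarray's private Variable._data attribute, which holds the lazy wrapper rather than a coerced array) would be to type-check the wrapper instead of the materialized .data. A toy sketch with stand-in classes, not real xarray code:

```python
class LazyAdapter:
    """Stand-in for xarray's index adapter: cheap to hold a reference to,
    expensive to coerce into a real array."""

class ToyVariable:
    """Stand-in for xarray.Variable; only ._data and .data matter here."""
    def __init__(self, wrapped):
        self._data = wrapped  # cheap: just a reference

    @property
    def data(self):
        # In the real code path, this is where the object-dtype array
        # would be allocated from the MultiIndex.
        raise MemoryError("coercing the MultiIndex would allocate here")

def wraps_type(var, array_type):
    # Inspect the wrapper's type without ever touching .data
    return isinstance(var._data, array_type)

v = ToyVariable(LazyAdapter())
print(wraps_type(v, LazyAdapter))  # True, and no allocation was triggered
```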

@dcherian
Contributor

dcherian commented Oct 27, 2021

(warning: untested code)

Instead of looking at all of self.variables we could

nonindexes = set(self.variables) - set(self.indexes)
# or alternatively make a list of multiindex variables names and exclude those

# then the condition becomes
any(is_duck_dask_array(self.variables[v].data) for v in nonindexes)
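A minimal sketch of the set-difference idea, with plain dicts standing in for Dataset.variables and Dataset.indexes (the names are illustrative, taken from the example above):

```python
# Stand-ins for self.variables / self.indexes on the Dataset in the example
variables = {"pixel": ..., "component": ..., "data": ...}
indexes = {"pixel": ...}  # the MultiIndex-backed coordinate

# Excluding index variables means v.data is never evaluated for "pixel",
# so the MultiIndex is never coerced to an object-dtype array.
nonindexes = set(variables) - set(indexes)
print(sorted(nonindexes))  # ['component', 'data']
```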

@dcherian
Contributor

PS: It doesn't seem to be the bottleneck in your case, but #5582 has an alternative proposal for unstacking dask arrays.

@TomAugspurger
Contributor Author

Thanks @dcherian, that seems to fix this performance problem. I'll see if the tests pass and will submit a PR.

I came across #5582 while searching, thanks :)

TomAugspurger pushed a commit to TomAugspurger/xarray that referenced this issue Oct 28, 2021