From 1a0ee51d927310f8524de29fd917ac1d34890922 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 26 Apr 2025 00:00:16 -0600 Subject: [PATCH 1/2] Avoid stacking when grouping by chunked array --- doc/whats-new.rst | 2 ++ xarray/core/groupby.py | 28 +++++++++++++++++++--------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e0a9853ee45..525f550102c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -50,6 +50,8 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Avoid stacking when grouping by a chunked array. This can be a large performance improvement. + By `Deepak Cherian `_. .. _whats-new.2025.03.1: diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 6f5472a014a..bf4715dd294 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -661,18 +661,26 @@ def __init__( # specification for the groupby operation # TODO: handle obj having variables that are not present on any of the groupers # simple broadcasting fails for ExtensionArrays. - # FIXME: Skip this stacking when grouping by a dask array, it's useless in that case. - (self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = _ensure_1d( - group=self.encoded.codes, obj=obj - ) - (self._group_dim,) = self.group1d.dims + codes = self.encoded.codes + self._by_chunked = is_chunked_array(codes._variable._data) + if not self._by_chunked: + (self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = ( + _ensure_1d(group=codes, obj=obj) + ) + (self._group_dim,) = self.group1d.dims + else: + self.group1d = None + # This transpose preserves dim order behaviour + self._obj = obj.transpose(..., *codes.dims) + self._stacked_dim = None + self._inserted_dims = None + self._group_dim = None # cached attributes self._groups = None self._dims = None self._sizes = None self._len = len(self.encoded.full_index) - self._by_chunked = is_chunked_array(self.encoded.codes.data) @property def sizes(self) -> Mapping[Hashable, int]: @@ -817,6 +825,7 @@ def __getitem__(self, key: GroupKey) -> T_Xarray: """ Get DataArray or Dataset corresponding to a particular group label. """ + self._raise_if_by_is_chunked() return self._obj.isel({self._group_dim: self.groups[key]}) def __len__(self) -> int: @@ -1331,9 +1340,6 @@ def quantile( "Sample quantiles in statistical packages," The American Statistician, 50(4), pp. 361-365, 1996 """ - if dim is None: - dim = (self._group_dim,) - # Dataset.quantile does this, do it for flox to ensure same output. q = np.asarray(q, dtype=np.float64) @@ -1348,6 +1354,8 @@ def quantile( ) return result else: + if dim is None: + dim = (self._group_dim,) return self.map( self._obj.__class__.quantile, shortcut=False, @@ -1491,6 +1499,7 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic): @property def dims(self) -> tuple[Hashable, ...]: + self._raise_if_by_is_chunked() if self._dims is None: index = self.encoded.group_indices[0] self._dims = self._obj.isel({self._group_dim: index}).dims @@ -1702,6 +1711,7 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic): @property def dims(self) -> Frozen[Hashable, int]: + self._raise_if_by_is_chunked() if self._dims is None: index = self.encoded.group_indices[0] self._dims = self._obj.isel({self._group_dim: index}).dims From 1ab8db0160c94f29e6def46198b9455fb50475ad Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 26 Apr 2025 11:12:32 -0600 Subject: [PATCH 2/2] fix mypy --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index bf4715dd294..37c264a20b1 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -673,7 +673,7 @@ def __init__( # This transpose preserves dim order behaviour self._obj = obj.transpose(..., *codes.dims) self._stacked_dim = None - self._inserted_dims = None + self._inserted_dims = [] self._group_dim = None # cached attributes