Skip to content

Commit e4f8768

Browse files
authored
Avoid stacking when grouping by chunked array (#10254)
1 parent bd10f9f commit e4f8768

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ Documentation
7878

7979
Internal Changes
8080
~~~~~~~~~~~~~~~~
81+
- Avoid stacking when grouping by a chunked array. This can be a large performance improvement. (:pull:`10254`)
82+
By `Deepak Cherian <https://github.com/dcherian>`_.
8183

8284
.. _whats-new.2025.03.1:
8385

xarray/core/groupby.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -661,18 +661,26 @@ def __init__(
661661
# specification for the groupby operation
662662
# TODO: handle obj having variables that are not present on any of the groupers
663663
# simple broadcasting fails for ExtensionArrays.
664-
# FIXME: Skip this stacking when grouping by a dask array, it's useless in that case.
665-
(self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = _ensure_1d(
666-
group=self.encoded.codes, obj=obj
667-
)
668-
(self._group_dim,) = self.group1d.dims
664+
codes = self.encoded.codes
665+
self._by_chunked = is_chunked_array(codes._variable._data)
666+
if not self._by_chunked:
667+
(self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = (
668+
_ensure_1d(group=codes, obj=obj)
669+
)
670+
(self._group_dim,) = self.group1d.dims
671+
else:
672+
self.group1d = None
673+
# Move the dims of `codes` to the end (`...` keeps the remaining dims first); this transpose preserves dim order behaviour
674+
self._obj = obj.transpose(..., *codes.dims)
675+
self._stacked_dim = None
676+
self._inserted_dims = []
677+
self._group_dim = None
669678

670679
# cached attributes
671680
self._groups = None
672681
self._dims = None
673682
self._sizes = None
674683
self._len = len(self.encoded.full_index)
675-
self._by_chunked = is_chunked_array(self.encoded.codes.data)
676684

677685
@property
678686
def sizes(self) -> Mapping[Hashable, int]:
@@ -817,6 +825,7 @@ def __getitem__(self, key: GroupKey) -> T_Xarray:
817825
"""
818826
Get DataArray or Dataset corresponding to a particular group label.
819827
"""
828+
self._raise_if_by_is_chunked()
820829
return self._obj.isel({self._group_dim: self.groups[key]})
821830

822831
def __len__(self) -> int:
@@ -1331,9 +1340,6 @@ def quantile(
13311340
"Sample quantiles in statistical packages,"
13321341
The American Statistician, 50(4), pp. 361-365, 1996
13331342
"""
1334-
if dim is None:
1335-
dim = (self._group_dim,)
1336-
13371343
# Dataset.quantile does this, do it for flox to ensure same output.
13381344
q = np.asarray(q, dtype=np.float64)
13391345

@@ -1348,6 +1354,8 @@ def quantile(
13481354
)
13491355
return result
13501356
else:
1357+
if dim is None:
1358+
dim = (self._group_dim,)
13511359
return self.map(
13521360
self._obj.__class__.quantile,
13531361
shortcut=False,
@@ -1491,6 +1499,7 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic):
14911499

14921500
@property
14931501
def dims(self) -> tuple[Hashable, ...]:
1502+
self._raise_if_by_is_chunked()
14941503
if self._dims is None:
14951504
index = self.encoded.group_indices[0]
14961505
self._dims = self._obj.isel({self._group_dim: index}).dims
@@ -1702,6 +1711,7 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic):
17021711

17031712
@property
17041713
def dims(self) -> Frozen[Hashable, int]:
1714+
self._raise_if_by_is_chunked()
17051715
if self._dims is None:
17061716
index = self.encoded.group_indices[0]
17071717
self._dims = self._obj.isel({self._group_dim: index}).dims

0 commit comments

Comments
 (0)