
Avoid stacking when grouping by chunked array #10254

Merged · 3 commits · Apr 28, 2025

2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -62,6 +62,8 @@ Documentation
 
 Internal Changes
 ~~~~~~~~~~~~~~~~
+- Avoid stacking when grouping by a chunked array. This can be a large performance improvement.
+  By `Deepak Cherian <https://github.com/dcherian>`_.
 
 .. _whats-new.2025.03.1:
 
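A minimal usage sketch of the behaviour this entry describes: grouping by a dask-chunked variable. The dataset, names, and chunk sizes are illustrative, and it assumes dask and flox are installed and that a chunked grouper needs its expected labels passed up front via UniqueGrouper(labels=...):

import numpy as np
import xarray as xr
from xarray.groupers import UniqueGrouper

ds = xr.Dataset(
    {"temp": ("time", np.random.rand(1_000))},
    coords={"labels": ("time", np.repeat(np.arange(10), 100))},
).chunk({"time": 100})  # requires dask; "labels" becomes dask-backed

# With a chunked grouping variable, GroupBy skips the stack/reshape of the
# object and only transposes it; the reduction is handled by flox on the codes.
result = ds.groupby(labels=UniqueGrouper(labels=np.arange(10))).mean()
print(result.compute())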
28 changes: 19 additions & 9 deletions xarray/core/groupby.py
@@ -661,18 +661,26 @@ def __init__(
         # specification for the groupby operation
         # TODO: handle obj having variables that are not present on any of the groupers
         # simple broadcasting fails for ExtensionArrays.
-        # FIXME: Skip this stacking when grouping by a dask array, it's useless in that case.
-        (self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = _ensure_1d(
-            group=self.encoded.codes, obj=obj
-        )
-        (self._group_dim,) = self.group1d.dims
+        codes = self.encoded.codes
+        self._by_chunked = is_chunked_array(codes._variable._data)
+        if not self._by_chunked:
+            (self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = (
+                _ensure_1d(group=codes, obj=obj)
+            )
+            (self._group_dim,) = self.group1d.dims
+        else:
+            self.group1d = None
+            # This transpose preserves dim order behaviour
+            self._obj = obj.transpose(..., *codes.dims)
+            self._stacked_dim = None
+            self._inserted_dims = []
+            self._group_dim = None
 
         # cached attributes
         self._groups = None
         self._dims = None
         self._sizes = None
         self._len = len(self.encoded.full_index)
-        self._by_chunked = is_chunked_array(self.encoded.codes.data)
 
     @property
     def sizes(self) -> Mapping[Hashable, int]:
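For orientation, the new branch keys off whether the underlying data of the codes is a chunk-aware (e.g. dask) array. A small illustrative check; note that is_chunked_array is internal, its import path may differ between xarray versions, and .chunk() requires dask:

import numpy as np
import xarray as xr
from xarray.namedarray.pycompat import is_chunked_array  # internal; path may vary

labels = xr.DataArray(np.repeat([0, 1], 5), dims="time")
print(is_chunked_array(labels.variable._data))          # False -> stacking path
print(is_chunked_array(labels.chunk().variable._data))  # True  -> chunked path, no stacking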
@@ -817,6 +825,7 @@ def __getitem__(self, key: GroupKey) -> T_Xarray:
         """
         Get DataArray or Dataset corresponding to a particular group label.
         """
+        self._raise_if_by_is_chunked()
         return self._obj.isel({self._group_dim: self.groups[key]})
 
     def __len__(self) -> int:
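The _raise_if_by_is_chunked helper called above is not part of this diff; a plausible sketch of such a guard, given how it is used here (the error message is hypothetical):

def _raise_if_by_is_chunked(self) -> None:
    # Hypothetical sketch: indexing and iteration need concrete group indices,
    # which do not exist when grouping lazily by a chunked array.
    if self._by_chunked:
        raise ValueError(
            "This method is not supported when lazily grouping by a chunked array. "
            "Load the grouping variable into memory first, e.g. with .compute()."
        )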
@@ -1331,9 +1340,6 @@ def quantile(
            "Sample quantiles in statistical packages,"
            The American Statistician, 50(4), pp. 361-365, 1996
         """
-        if dim is None:
-            dim = (self._group_dim,)
-
         # Dataset.quantile does this, do it for flox to ensure same output.
         q = np.asarray(q, dtype=np.float64)
 
@@ -1348,6 +1354,8 @@
             )
             return result
         else:
+            if dim is None:
+                dim = (self._group_dim,)
             return self.map(
                 self._obj.__class__.quantile,
                 shortcut=False,
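The dim default moves into the non-flox branch because self._group_dim is None when grouping by a chunked array; the flox path derives the reduction dimension from the codes itself, so only the map-based fallback needs the default. A small sketch of that fallback on eager data, with flox explicitly disabled for illustration:

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(12.0),
    dims="time",
    coords={"labels": ("time", np.tile(["a", "b", "c"], 4))},
)
with xr.set_options(use_flox=False):
    # dim is None, so it defaults to the group dimension inside the else branch
    q = da.groupby("labels").quantile(0.5)
print(q)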
@@ -1491,6 +1499,7 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic):
 
     @property
     def dims(self) -> tuple[Hashable, ...]:
+        self._raise_if_by_is_chunked()
         if self._dims is None:
             index = self.encoded.group_indices[0]
             self._dims = self._obj.isel({self._group_dim: index}).dims
@@ -1702,6 +1711,7 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic):
 
     @property
     def dims(self) -> Frozen[Hashable, int]:
+        self._raise_if_by_is_chunked()
         if self._dims is None:
             index = self.encoded.group_indices[0]
             self._dims = self._obj.isel({self._group_dim: index}).dims