Commit ad5c7ed

Add GroupBy.shuffle_to_chunks() (#9320)

* Add GroupBy.shuffle()
* Cleanup
* Cleanup
* fix
* return groupby instance from shuffle
* Fix nD by
* Skip if no dask
* fix tests
* Add `chunks` to signature
* FIx self
* Another Self fix
* Forward chunks too
* [revert]
* undo flox limit
* [revert]
* fix types
* Add DataArray.shuffle_by, Dataset.shuffle_by
* Add doctest
* Refactor
* tweak docstrings
* fix typing
* Fix
* fix docstring
* bump min version to dask>=2024.08.1
* Fix typing
* Fix types
* remove shuffle_by for now.
* Add tests
* Support shuffling with multiple groupers
* Revert "remove shuffle_by for now." This reverts commit 7a99c8f.
* bad merge
* Add a test
* Add docs
* bugfix
* Refactor out Dataset._shuffle
* fix types
* fix tests
* Handle by is chunked
* Some refactoring
* Remove shuffle_by
* shuffle -> distributed_shuffle
* return xarray object from distributed_shuffle
* fix
* fix doctest
* fix api
* Rename to `shuffle_to_chunks`
* update docs

1 parent 077276a commit ad5c7ed

14 files changed (+421, -38 lines)

doc/api.rst

Lines changed: 2 additions & 0 deletions
@@ -1210,6 +1210,7 @@ Dataset
    DatasetGroupBy.var
    DatasetGroupBy.dims
    DatasetGroupBy.groups
+   DatasetGroupBy.shuffle_to_chunks

 DataArray
 ---------
@@ -1241,6 +1242,7 @@ DataArray
    DataArrayGroupBy.var
    DataArrayGroupBy.dims
    DataArrayGroupBy.groups
+   DataArrayGroupBy.shuffle_to_chunks

 Grouper Objects
 ---------------

doc/user-guide/groupby.rst

Lines changed: 21 additions & 0 deletions
@@ -330,3 +330,24 @@ Different groupers can be combined to construct sophisticated GroupBy operations
     from xarray.groupers import BinGrouper

     ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
+
+
+Shuffling
+~~~~~~~~~
+
+Shuffling is a generalization of sorting a DataArray or Dataset by another DataArray, named ``label`` for example, that follows from the idea of grouping by ``label``.
+Shuffling reorders the DataArray or the DataArrays in a Dataset such that all members of a group occur sequentially.
+For example, shuffle the object using either :py:class:`DatasetGroupBy` or :py:class:`DataArrayGroupBy` as appropriate:
+
+.. ipython:: python
+
+    da = xr.DataArray(
+        dims="x",
+        data=[1, 2, 3, 4, 5, 6],
+        coords={"label": ("x", "a b c a b c".split(" "))},
+    )
+    da.groupby("label").shuffle_to_chunks()
+
+
+For chunked array types (e.g. dask or cubed), shuffling may result in a more optimized communication pattern than direct indexing by the appropriate indexer.
+Shuffling also makes GroupBy operations on chunked arrays an embarrassingly parallel problem, and may significantly improve workloads that use :py:meth:`DatasetGroupBy.map` or :py:meth:`DataArrayGroupBy.map`.
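
As an illustration of the map workflow this documentation describes (not part of the commit; assumes dask is installed, variable and coordinate names are made up), one might shuffle once and then re-group and map:

    import dask.array
    import xarray as xr

    da = xr.DataArray(
        dims="x",
        data=dask.array.arange(6, chunks=2),
        coords={"label": ("x", ["a", "b", "c", "a", "b", "c"])},
        name="a",
    )
    # After shuffling, all members of each "label" group share a chunk,
    # so the mapped function sees whole groups without extra communication.
    shuffled = da.groupby("label").shuffle_to_chunks()
    anomaly = shuffled.groupby("label").map(lambda g: g - g.mean())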

xarray/core/dataarray.py

Lines changed: 8 additions & 0 deletions
@@ -63,6 +63,7 @@
     Bins,
     DaCompatible,
     NetcdfWriteModes,
+    T_Chunks,
     T_DataArray,
     T_DataArrayOrSet,
     ZarrWriteModes,
@@ -105,6 +106,7 @@
     Dims,
     ErrorOptions,
     ErrorOptionsWithWarn,
+    GroupIndices,
     GroupInput,
     InterpOptions,
     PadModeOptions,
@@ -1687,6 +1689,12 @@ def sel(
         )
         return self._from_temp_dataset(ds)

+    def _shuffle(
+        self, dim: Hashable, *, indices: GroupIndices, chunks: T_Chunks
+    ) -> Self:
+        ds = self._to_temp_dataset()._shuffle(dim=dim, indices=indices, chunks=chunks)
+        return self._from_temp_dataset(ds)
+
     def head(
         self,
         indexers: Mapping[Any, int] | int | None = None,

xarray/core/dataset.py

Lines changed: 34 additions & 0 deletions
@@ -155,6 +155,7 @@
     DsCompatible,
     ErrorOptions,
     ErrorOptionsWithWarn,
+    GroupIndices,
     GroupInput,
     InterpOptions,
     JoinOptions,
@@ -166,6 +167,7 @@
     ResampleCompatible,
     SideOptions,
     T_ChunkDimFreq,
+    T_Chunks,
     T_DatasetPadConstantValues,
     T_Xarray,
 )
@@ -3237,6 +3239,38 @@ def sel(
         result = self.isel(indexers=query_results.dim_indexers, drop=drop)
         return result._overwrite_indexes(*query_results.as_tuple()[1:])

+    def _shuffle(self, dim, *, indices: GroupIndices, chunks: T_Chunks) -> Self:
+        # Shuffling is only different from `isel` for chunked arrays.
+        # Extract them out, and treat them specially. The rest, we route through isel.
+        # This makes it easy to ensure correct handling of indexes.
+        is_chunked = {
+            name: var
+            for name, var in self._variables.items()
+            if is_chunked_array(var._data)
+        }
+        subset = self[[name for name in self._variables if name not in is_chunked]]
+
+        no_slices: list[list[int]] = [
+            list(range(*idx.indices(self.sizes[dim])))
+            if isinstance(idx, slice)
+            else idx
+            for idx in indices
+        ]
+        no_slices = [idx for idx in no_slices if idx]
+
+        shuffled = (
+            subset
+            if dim not in subset.dims
+            else subset.isel({dim: np.concatenate(no_slices)})
+        )
+        for name, var in is_chunked.items():
+            shuffled[name] = var._shuffle(
+                indices=no_slices,
+                dim=dim,
+                chunks=chunks,
+            )
+        return shuffled
+
     def head(
         self,
         indexers: Mapping[Any, int] | int | None = None,
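
For reference, a small standalone sketch (not part of the diff; values are illustrative) of the `no_slices` expansion in `Dataset._shuffle` above: slice-valued group indices are expanded to explicit integer lists via `slice.indices` before being concatenated for `isel`:

    # Illustrative values; `size` stands in for self.sizes[dim].
    size = 10
    indices = [slice(0, 3), [5, 7], slice(8, 10)]
    no_slices = [
        list(range(*idx.indices(size))) if isinstance(idx, slice) else idx
        for idx in indices
    ]
    print(no_slices)  # [[0, 1, 2], [5, 7], [8, 9]]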

xarray/core/groupby.py

Lines changed: 80 additions & 2 deletions
@@ -57,7 +57,13 @@

     from xarray.core.dataarray import DataArray
     from xarray.core.dataset import Dataset
-    from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey
+    from xarray.core.types import (
+        GroupIndex,
+        GroupIndices,
+        GroupInput,
+        GroupKey,
+        T_Chunks,
+    )
     from xarray.core.utils import Frozen
     from xarray.groupers import EncodedGroups, Grouper

@@ -676,6 +682,76 @@ def sizes(self) -> Mapping[Hashable, int]:
             self._sizes = self._obj.isel({self._group_dim: index}).sizes
         return self._sizes

+    def shuffle_to_chunks(self, chunks: T_Chunks = None) -> T_Xarray:
+        """
+        Sort or "shuffle" the underlying object.
+
+        "Shuffle" means the object is sorted so that all group members occur sequentially,
+        in the same chunk. Multiple groups may occur in the same chunk.
+        This method is particularly useful for chunked arrays (e.g. dask, cubed),
+        when you need to map a function that requires all members of a group
+        to be present in a single chunk. For chunked array types, the order of appearance
+        is not guaranteed, but will depend on the input chunking.
+
+        Parameters
+        ----------
+        chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
+            How to adjust chunks along dimensions not present in the array being grouped by.
+
+        Returns
+        -------
+        DataArrayGroupBy or DatasetGroupBy
+
+        Examples
+        --------
+        >>> import dask.array
+        >>> da = xr.DataArray(
+        ...     dims="x",
+        ...     data=dask.array.arange(10, chunks=3),
+        ...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        ...     name="a",
+        ... )
+        >>> shuffled = da.groupby("x").shuffle_to_chunks()
+        >>> shuffled
+        <xarray.DataArray 'a' (x: 10)> Size: 80B
+        dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>
+        Coordinates:
+          * x        (x) int64 80B 0 1 1 1 2 2 2 3 3 3
+
+        >>> shuffled.groupby("x").quantile(q=0.5).compute()
+        <xarray.DataArray 'a' (x: 4)> Size: 32B
+        array([9., 3., 4., 5.])
+        Coordinates:
+            quantile  float64 8B 0.5
+          * x         (x) int64 32B 0 1 2 3
+
+        See Also
+        --------
+        dask.dataframe.DataFrame.shuffle
+        dask.array.shuffle
+        """
+        self._raise_if_by_is_chunked()
+        return self._shuffle_obj(chunks)
+
+    def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
+        from xarray.core.dataarray import DataArray
+
+        was_array = isinstance(self._obj, DataArray)
+        as_dataset = self._obj._to_temp_dataset() if was_array else self._obj
+
+        for grouper in self.groupers:
+            if grouper.name not in as_dataset._variables:
+                as_dataset.coords[grouper.name] = grouper.group
+
+        shuffled = as_dataset._shuffle(
+            dim=self._group_dim, indices=self.encoded.group_indices, chunks=chunks
+        )
+        unstacked: Dataset = self._maybe_unstack(shuffled)
+        if was_array:
+            return self._obj._from_temp_dataset(unstacked)
+        else:
+            return unstacked  # type: ignore[return-value]
+
     def map(
         self,
         func: Callable,
@@ -896,7 +972,9 @@ def _maybe_unstack(self, obj):
             # and `inserted_dims`
             # if multiple groupers all share the same single dimension, then
             # we don't stack/unstack. Do that manually now.
-            obj = obj.unstack(*self.encoded.unique_coord.dims)
+            dims_to_unstack = self.encoded.unique_coord.dims
+            if all(dim in obj.dims for dim in dims_to_unstack):
+                obj = obj.unstack(*dims_to_unstack)
             to_drop = [
                 grouper.name
                 for grouper in self.groupers
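
The docstring example above uses a DataArray; per the api.rst change, the same method is available on DatasetGroupBy. A minimal sketch (not part of the diff; assumes dask is installed, names are illustrative):

    import dask.array
    import xarray as xr

    ds = xr.Dataset(
        {"a": ("x", dask.array.arange(6, chunks=2))},
        coords={"label": ("x", ["a", "b", "c", "a", "b", "c"])},
    )
    # Every chunked variable in the Dataset is reordered so that each
    # "label" group occupies a single chunk.
    shuffled = ds.groupby("label").shuffle_to_chunks()
    result = shuffled.groupby("label").map(lambda g: g - g.mean("x"))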

xarray/core/resample.py

Lines changed: 45 additions & 0 deletions
@@ -14,6 +14,7 @@
 if TYPE_CHECKING:
     from xarray.core.dataarray import DataArray
     from xarray.core.dataset import Dataset
+    from xarray.core.types import T_Chunks

 from xarray.groupers import RESAMPLE_DIM

@@ -58,6 +59,50 @@ def _flox_reduce(
         result = result.rename({RESAMPLE_DIM: self._group_dim})
         return result

+    def shuffle_to_chunks(self, chunks: T_Chunks = None):
+        """
+        Sort or "shuffle" the underlying object.
+
+        "Shuffle" means the object is sorted so that all group members occur sequentially,
+        in the same chunk. Multiple groups may occur in the same chunk.
+        This method is particularly useful for chunked arrays (e.g. dask, cubed),
+        when you need to map a function that requires all members of a group
+        to be present in a single chunk. For chunked array types, the order of appearance
+        is not guaranteed, but will depend on the input chunking.
+
+        Parameters
+        ----------
+        chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
+            How to adjust chunks along dimensions not present in the array being grouped by.
+
+        Returns
+        -------
+        DataArrayGroupBy or DatasetGroupBy
+
+        Examples
+        --------
+        >>> import dask.array
+        >>> da = xr.DataArray(
+        ...     dims="time",
+        ...     data=dask.array.arange(10, chunks=1),
+        ...     coords={"time": xr.date_range("2001-01-01", freq="12h", periods=10)},
+        ...     name="a",
+        ... )
+        >>> shuffled = da.resample(time="2D").shuffle_to_chunks()
+        >>> shuffled
+        <xarray.DataArray 'a' (time: 10)> Size: 80B
+        dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(4,), chunktype=numpy.ndarray>
+        Coordinates:
+          * time     (time) datetime64[ns] 80B 2001-01-01 ... 2001-01-05T12:00:00
+
+        See Also
+        --------
+        dask.dataframe.DataFrame.shuffle
+        dask.array.shuffle
+        """
+        (grouper,) = self.groupers
+        return self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM)
+
     def _drop_coords(self) -> T_Xarray:
         """Drop non-dimension coordinates along the resampled dimension."""
         obj = self._obj
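
A short follow-on to the docstring example above (not part of the diff; assumes dask is installed): shuffle the resample groups, then re-resample and reduce.

    import dask.array
    import xarray as xr

    da = xr.DataArray(
        dims="time",
        data=dask.array.arange(10, chunks=1),
        coords={"time": xr.date_range("2001-01-01", freq="12h", periods=10)},
        name="a",
    )
    # Each 2-day bin now lives in one chunk, so the reduction is chunk-local.
    shuffled = da.resample(time="2D").shuffle_to_chunks()
    daily = shuffled.resample(time="2D").mean().compute()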

xarray/core/types.py

Lines changed: 1 addition & 1 deletion
@@ -362,7 +362,7 @@ def read(self, __n: int = ...) -> AnyStr_co:
 ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]

 GroupKey = Any
-GroupIndex = Union[int, slice, list[int]]
+GroupIndex = Union[slice, list[int]]
 GroupIndices = tuple[GroupIndex, ...]
 Bins = Union[
     int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index

xarray/core/variable.py

Lines changed: 25 additions & 1 deletion
@@ -45,7 +45,13 @@
     maybe_coerce_to_str,
 )
 from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
-from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
+from xarray.namedarray.parallelcompat import get_chunked_array_type
+from xarray.namedarray.pycompat import (
+    integer_types,
+    is_0d_dask_array,
+    is_chunked_array,
+    to_duck_array,
+)
 from xarray.namedarray.utils import module_available
 from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims

@@ -1019,6 +1025,24 @@ def compute(self, **kwargs):
         new = self.copy(deep=False)
         return new.load(**kwargs)

+    def _shuffle(
+        self, indices: list[list[int]], dim: Hashable, chunks: T_Chunks
+    ) -> Self:
+        # TODO (dcherian): consider making this public API
+        array = self._data
+        if is_chunked_array(array):
+            chunkmanager = get_chunked_array_type(array)
+            return self._replace(
+                data=chunkmanager.shuffle(
+                    array,
+                    indexer=indices,
+                    axis=self.get_axis_num(dim),
+                    chunks=chunks,
+                )
+            )
+        else:
+            return self.isel({dim: np.concatenate(indices)})
+
     def isel(
         self,
         indexers: Mapping[Any, Any] | None = None,
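
For the non-chunked branch of `Variable._shuffle` above, the reorder reduces to a single fancy-index with the concatenated group indices. A tiny standalone illustration (not part of the diff; values are made up):

    import numpy as np

    data = np.array([10, 11, 12, 13, 14, 15])
    indices = [[0, 3], [1, 4], [2, 5]]  # per-group integer indices (illustrative)
    # np.concatenate flattens the group lists; indexing with the result
    # places each group's members next to each other.
    reordered = data[np.concatenate(indices)]
    print(reordered)  # [10 13 11 14 12 15]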
