Skip to content

Commit ce841d5

Browse files
committed
Group by multiple strings
Closes pydata#9396
1 parent d33e4ad commit ce841d5

File tree

5 files changed

+85
-28
lines changed

5 files changed

+85
-28
lines changed

xarray/core/dataarray.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
Dims,
103103
ErrorOptions,
104104
ErrorOptionsWithWarn,
105+
GroupInput,
105106
InterpOptions,
106107
PadModeOptions,
107108
PadReflectOptions,
@@ -6706,10 +6707,7 @@ def interp_calendar(
67066707
@_deprecate_positional_args("v2024.07.0")
67076708
def groupby(
67086709
self,
6709-
group: (
6710-
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
6711-
) = None,
6712-
*,
6710+
group: GroupInput = None,
67136711
squeeze: Literal[False] = False,
67146712
restore_coord_dims: bool = False,
67156713
**groupers: Grouper,
@@ -6718,7 +6716,7 @@ def groupby(
67186716
67196717
Parameters
67206718
----------
6721-
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
6719+
group : str or DataArray or IndexVariable or sequence of Hashable or mapping of Hashable to Grouper
67226720
Array whose unique values should be used to group this array. If a
67236721
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
67246722
must map an existing variable name to a :py:class:`Grouper` instance.
@@ -6788,29 +6786,35 @@ def groupby(
67886786
Dataset.resample
67896787
DataArray.resample
67906788
"""
6789+
from xarray.core.dataarray import DataArray
67916790
from xarray.core.groupby import (
67926791
DataArrayGroupBy,
67936792
ResolvedGrouper,
6793+
_validate_group_and_groupers,
67946794
_validate_groupby_squeeze,
67956795
)
6796+
from xarray.core.variable import Variable
67966797
from xarray.groupers import UniqueGrouper
67976798

67986799
_validate_groupby_squeeze(squeeze)
6800+
_validate_group_and_groupers(group, groupers)
67996801

68006802
if isinstance(group, Mapping):
68016803
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
68026804
group = None
68036805

68046806
rgroupers: tuple[ResolvedGrouper, ...]
6805-
if group is not None:
6806-
if groupers:
6807-
raise ValueError(
6808-
"Providing a combination of `group` and **groupers is not supported."
6809-
)
6807+
if isinstance(group, DataArray | Variable):
68106808
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
68116809
else:
6812-
if not groupers:
6813-
raise ValueError("Either `group` or `**groupers` must be provided.")
6810+
if group is not None:
6811+
if TYPE_CHECKING:
6812+
assert isinstance(group, str | Iterable)
6813+
group_iter: Iterable[Hashable] = (
6814+
(group,) if isinstance(group, str) else group
6815+
)
6816+
groupers = {g: UniqueGrouper() for g in group_iter}
6817+
68146818
rgroupers = tuple(
68156819
ResolvedGrouper(grouper, group, self)
68166820
for group, grouper in groupers.items()

xarray/core/dataset.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@
154154
DsCompatible,
155155
ErrorOptions,
156156
ErrorOptionsWithWarn,
157+
GroupInput,
157158
InterpOptions,
158159
JoinOptions,
159160
PadModeOptions,
@@ -10331,10 +10332,7 @@ def interp_calendar(
1033110332
@_deprecate_positional_args("v2024.07.0")
1033210333
def groupby(
1033310334
self,
10334-
group: (
10335-
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
10336-
) = None,
10337-
*,
10335+
group: GroupInput = None,
1033810336
squeeze: Literal[False] = False,
1033910337
restore_coord_dims: bool = False,
1034010338
**groupers: Grouper,
@@ -10343,7 +10341,7 @@ def groupby(
1034310341
1034410342
Parameters
1034510343
----------
10346-
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
10344+
group : str or DataArray or IndexVariable or sequence of hashable or mapping of Hashable to Grouper
1034710345
Array whose unique values should be used to group this array. If a
1034810346
Hashable, must be the name of a coordinate contained in this dataset. If a dictionary,
1034910347
must map an existing variable name to a :py:class:`Grouper` instance.
@@ -10384,29 +10382,35 @@ def groupby(
1038410382
Dataset.resample
1038510383
DataArray.resample
1038610384
"""
10385+
from xarray.core.dataarray import DataArray
1038710386
from xarray.core.groupby import (
1038810387
DatasetGroupBy,
1038910388
ResolvedGrouper,
10389+
_validate_group_and_groupers,
1039010390
_validate_groupby_squeeze,
1039110391
)
10392+
from xarray.core.variable import Variable
1039210393
from xarray.groupers import UniqueGrouper
1039310394

1039410395
_validate_groupby_squeeze(squeeze)
10396+
_validate_group_and_groupers(group, groupers)
1039510397

1039610398
if isinstance(group, Mapping):
1039710399
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
1039810400
group = None
1039910401

1040010402
rgroupers: tuple[ResolvedGrouper, ...]
10401-
if group is not None:
10402-
if groupers:
10403-
raise ValueError(
10404-
"Providing a combination of `group` and **groupers is not supported."
10405-
)
10403+
if isinstance(group, DataArray | Variable):
1040610404
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
1040710405
else:
10408-
if not groupers:
10409-
raise ValueError("Either `group` or `**groupers` must be provided.")
10406+
if group is not None:
10407+
if TYPE_CHECKING:
10408+
assert isinstance(group, str | Iterable)
10409+
group_iter: Iterable[Hashable] = (
10410+
(group,) if isinstance(group, str) else group
10411+
)
10412+
groupers = {g: UniqueGrouper() for g in group_iter}
10413+
1041010414
rgroupers = tuple(
1041110415
ResolvedGrouper(grouper, group, self)
1041210416
for group, grouper in groupers.items()

xarray/core/groupby.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454

5555
from xarray.core.dataarray import DataArray
5656
from xarray.core.dataset import Dataset
57-
from xarray.core.types import GroupIndex, GroupIndices, GroupKey
57+
from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey
5858
from xarray.core.utils import Frozen
5959
from xarray.groupers import EncodedGroups, Grouper
6060

@@ -319,6 +319,21 @@ def __len__(self) -> int:
319319
return len(self.encoded.full_index)
320320

321321

322+
def _validate_group_and_groupers(group: GroupInput, groupers: dict[str, Grouper]):
323+
if group is not None and groupers:
324+
raise ValueError(
325+
"Providing a combination of `group` and **groupers is not supported."
326+
)
327+
328+
if group is None and not groupers:
329+
raise ValueError("Either `group` or `**groupers` must be provided.")
330+
331+
if isinstance(group, np.ndarray | pd.Index):
332+
raise TypeError(
333+
f"`group` must be a DataArray. Received {type(group).__name__!r} instead"
334+
)
335+
336+
322337
def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
323338
# While we don't generally check the type of every arg, passing
324339
# multiple dimensions as multiple arguments is common enough, and the
@@ -327,7 +342,7 @@ def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
327342
# A future version could make squeeze kwarg only, but would face
328343
# backward-compat issues.
329344
if squeeze is not False:
330-
raise TypeError(f"`squeeze` must be False, but {squeeze} was supplied.")
345+
raise TypeError(f"`squeeze` must be False, but {squeeze!r} was supplied.")
331346

332347

333348
def _resolve_group(

xarray/core/types.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,17 @@
4242
from xarray.core.dataset import Dataset
4343
from xarray.core.indexes import Index, Indexes
4444
from xarray.core.utils import Frozen
45-
from xarray.core.variable import Variable
46-
from xarray.groupers import TimeResampler
45+
from xarray.core.variable import IndexVariable, Variable
46+
from xarray.groupers import Grouper, TimeResampler
47+
48+
GroupInput: TypeAlias = (
49+
str
50+
| DataArray
51+
| IndexVariable
52+
| Sequence[Hashable]
53+
| Mapping[Any, Grouper]
54+
| None
55+
)
4756

4857
try:
4958
from dask.array import Array as DaskArray

xarray/tests/test_groupby.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2635,6 +2635,31 @@ def test_weather_data_resample(use_flox):
26352635
assert expected.location.attrs == ds.location.attrs
26362636

26372637

2638+
@pytest.mark.parametrize("as_dataset", [True, False])
def test_multiple_groupers_string(as_dataset) -> None:
    """Grouping by a tuple of names must match grouping by explicit UniqueGroupers."""
    first_labels = np.array(["a", "b", "c", "c", "b", "a"])
    second_labels = np.array(["x", "y", "z", "z", "y", "x"])
    obj = DataArray(
        np.array([1, 2, 3, 0, 2, np.nan]),
        dims="d",
        coords={
            "labels1": ("d", first_labels),
            "labels2": ("d", second_labels),
        },
        name="foo",
    )

    # Run the same checks against the Dataset flavour when requested.
    if as_dataset:
        obj = obj.to_dataset()

    actual = obj.groupby(("labels1", "labels2")).mean()
    expected = obj.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper()).mean()
    assert_identical(expected, actual)

    # A second positional name is expected to be rejected with TypeError.
    with pytest.raises(TypeError):
        obj.groupby("labels1", "labels2")

    # Mixing a positional group with keyword groupers is expected to raise ValueError.
    with pytest.raises(ValueError):
        obj.groupby("labels1", foo="bar")
2661+
2662+
26382663
@pytest.mark.parametrize("use_flox", [True, False])
26392664
def test_multiple_groupers(use_flox) -> None:
26402665
da = DataArray(

0 commit comments

Comments
 (0)