Skip to content

Commit ee18848

Browse files
committed
GroupBy(multiple groupers)
Closes pydata#924 Closes pydata#1056 Closes pydata#9332 xref pydata#324
1 parent 3c19231 commit ee18848

File tree

4 files changed: +100 -67 lines changed

4 files changed: +100 -67 lines changed

xarray/core/dataarray.py

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -6801,27 +6801,21 @@ def groupby(
68016801
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
68026802
group = None
68036803

6804-
grouper: Grouper
68056804
if group is not None:
68066805
if groupers:
68076806
raise ValueError(
68086807
"Providing a combination of `group` and **groupers is not supported."
68096808
)
6810-
grouper = UniqueGrouper()
6809+
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
68116810
else:
6812-
if len(groupers) > 1:
6813-
raise ValueError("grouping by multiple variables is not supported yet.")
68146811
if not groupers:
68156812
raise ValueError("Either `group` or `**groupers` must be provided.")
6816-
group, grouper = next(iter(groupers.items()))
6817-
6818-
rgrouper = ResolvedGrouper(grouper, group, self)
6813+
rgroupers = tuple(
6814+
ResolvedGrouper(grouper, group, self)
6815+
for group, grouper in groupers.items()
6816+
)
68196817

6820-
return DataArrayGroupBy(
6821-
self,
6822-
(rgrouper,),
6823-
restore_coord_dims=restore_coord_dims,
6824-
)
6818+
return DataArrayGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)
68256819

68266820
@_deprecate_positional_args("v2024.07.0")
68276821
def groupby_bins(

xarray/core/dataset.py

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -10388,20 +10388,16 @@ def groupby(
1038810388
raise ValueError(
1038910389
"Providing a combination of `group` and **groupers is not supported."
1039010390
)
10391-
rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
10391+
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
1039210392
else:
10393-
if len(groupers) > 1:
10394-
raise ValueError("Grouping by multiple variables is not supported yet.")
10395-
elif not groupers:
10393+
if not groupers:
1039610394
raise ValueError("Either `group` or `**groupers` must be provided.")
10397-
for group, grouper in groupers.items():
10398-
rgrouper = ResolvedGrouper(grouper, group, self)
10395+
rgroupers = tuple(
10396+
ResolvedGrouper(grouper, group, self)
10397+
for group, grouper in groupers.items()
10398+
)
1039910399

10400-
return DatasetGroupBy(
10401-
self,
10402-
(rgrouper,),
10403-
restore_coord_dims=restore_coord_dims,
10404-
)
10400+
return DatasetGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)
1040510401

1040610402
@_deprecate_positional_args("v2024.07.0")
1040710403
def groupby_bins(

xarray/core/groupby.py

Lines changed: 60 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

33
import copy
4+
import functools
5+
import math
46
import warnings
57
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
68
from dataclasses import dataclass, field
@@ -68,10 +70,11 @@ def check_reduce_dims(reduce_dims, dimensions):
6870
)
6971

7072

71-
def _codes_to_group_indices(inverse: np.ndarray, N: int) -> GroupIndices:
72-
assert inverse.ndim == 1
73+
def _codes_to_group_indices(codes: np.ndarray, N: int) -> GroupIndices:
74+
"""Converts integer codes for groups to group indices."""
75+
assert codes.ndim == 1
7376
groups: GroupIndices = tuple([] for _ in range(N))
74-
for n, g in enumerate(inverse):
77+
for n, g in enumerate(codes):
7578
if g >= 0:
7679
groups[g].append(n)
7780
return groups
@@ -448,7 +451,7 @@ class GroupBy(Generic[T_Xarray]):
448451
"_codes",
449452
)
450453
_obj: T_Xarray
451-
groupers: tuple[ResolvedGrouper]
454+
groupers: tuple[ResolvedGrouper, ...]
452455
_restore_coord_dims: bool
453456

454457
_original_obj: T_Xarray
@@ -464,7 +467,7 @@ class GroupBy(Generic[T_Xarray]):
464467
def __init__(
465468
self,
466469
obj: T_Xarray,
467-
groupers: tuple[ResolvedGrouper],
470+
groupers: tuple[ResolvedGrouper, ...],
468471
restore_coord_dims: bool = True,
469472
) -> None:
470473
"""Create a GroupBy object
@@ -483,16 +486,35 @@ def __init__(
483486

484487
self._original_obj = obj
485488

486-
(grouper,) = self.groupers
487-
self._original_group = grouper.group
489+
if len(groupers) > 1:
490+
for grouper in groupers:
491+
if grouper.group.ndim > 1:
492+
raise NotImplementedError(
493+
"Only grouping by multiple 1D variables is supported at the moment."
494+
)
495+
(grouper, *_) = self.groupers # FIXME
496+
self._original_group = grouper.group # FIXME
488497

489498
# specification for the groupby operation
490-
self._obj = grouper.stacked_obj
499+
self._obj = grouper.stacked_obj # FIXME
491500
self._restore_coord_dims = restore_coord_dims
492501

493-
# These should generalize to multiple groupers
494-
self._group_indices = grouper.group_indices
495-
self._codes = self._maybe_unstack(grouper.codes)
502+
self._shape = tuple(grouper.size for grouper in groupers)
503+
self._len = math.prod(self._shape)
504+
505+
self._codes = tuple(self._maybe_unstack(grouper.codes) for grouper in groupers)
506+
self._flatcodes = np.ravel_multi_index(self._codes, self._shape, mode="wrap")
507+
# NaNs; as well as values outside the bins are coded by -1
508+
# Restore these after the raveling
509+
mask = functools.reduce(np.logical_or, [(code == -1) for code in self._codes])
510+
self._flatcodes[mask] = -1
511+
512+
if len(groupers) == 1:
513+
# For ordered `group` we index into the array using slices.
514+
# Preserve this optimization when grouping by a single variable
515+
self._group_indices = self.groupers[0].group_indices
516+
else:
517+
self._group_indices = _codes_to_group_indices(self._flatcodes, self._len)
496518

497519
(self._group_dim,) = grouper.group1d.dims
498520
# cached attributes
@@ -566,13 +588,16 @@ def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]:
566588
return zip(grouper.unique_coord.data, self._iter_grouped())
567589

568590
def __repr__(self) -> str:
569-
(grouper,) = self.groupers
570-
return "{}, grouped over {!r}\n{!r} groups with labels {}.".format(
571-
self.__class__.__name__,
572-
grouper.name,
573-
grouper.full_index.size,
574-
", ".join(format_array_flat(grouper.full_index, 30).split()),
591+
text = (
592+
f"<{self.__class__.__name__}, "
593+
f"grouped over {len(self.groupers)} grouper(s),"
594+
f" {self._len} groups in total:"
575595
)
596+
for grouper in self.groupers:
597+
coord = grouper.unique_coord
598+
labels = ", ".join(format_array_flat(coord, 30).split())
599+
text += f"\n\t{grouper.name!r}: {coord.size} groups with labels {labels}"
600+
return text + ">"
576601

577602
def _iter_grouped(self) -> Iterator[T_Xarray]:
578603
"""Iterate over each element in this group"""
@@ -609,7 +634,7 @@ def _binary_op(self, other, f, reflexive=False):
609634
obj = self._original_obj
610635
name = grouper.name
611636
group = grouper.group
612-
codes = self._codes
637+
(codes,) = self._codes
613638
dims = group.dims
614639

615640
if isinstance(group, _DummyGroup):
@@ -709,15 +734,16 @@ def _maybe_restore_empty_groups(self, combined):
709734
def _maybe_unstack(self, obj):
710735
"""This gets called if we are applying on an array with a
711736
multidimensional group."""
712-
(grouper,) = self.groupers
713-
stacked_dim = grouper.stacked_dim
714-
inserted_dims = grouper.inserted_dims
715-
if stacked_dim is not None and stacked_dim in obj.dims:
716-
obj = obj.unstack(stacked_dim)
717-
for dim in inserted_dims:
718-
if dim in obj.coords:
719-
del obj.coords[dim]
720-
obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords))
737+
# TODO: Is this really right?
738+
for grouper in self.groupers:
739+
stacked_dim = grouper.stacked_dim
740+
if stacked_dim is not None and stacked_dim in obj.dims:
741+
inserted_dims = grouper.inserted_dims
742+
obj = obj.unstack(stacked_dim)
743+
for dim in inserted_dims:
744+
if dim in obj.coords:
745+
del obj.coords[dim]
746+
obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords))
721747
return obj
722748

723749
def _flox_reduce(
@@ -1115,20 +1141,21 @@ def _concat_shortcut(self, applied, dim, positions=None):
11151141
return self._obj._replace_maybe_drop_dims(reordered)
11161142

11171143
def _restore_dim_order(self, stacked: DataArray) -> DataArray:
1118-
(grouper,) = self.groupers
1119-
group = grouper.group1d
1120-
11211144
def lookup_order(dimension):
1122-
if dimension == grouper.name:
1123-
(dimension,) = group.dims
1145+
for grouper in self.groupers:
1146+
if dimension == grouper.name and grouper.group.ndim == 1:
1147+
(dimension,) = grouper.group.dims
11241148
if dimension in self._obj.dims:
11251149
axis = self._obj.get_axis_num(dimension)
11261150
else:
11271151
axis = 1e6 # some arbitrarily high value
11281152
return axis
11291153

11301154
new_order = sorted(stacked.dims, key=lookup_order)
1131-
return stacked.transpose(*new_order, transpose_coords=self._restore_coord_dims)
1155+
stacked = stacked.transpose(
1156+
*new_order, transpose_coords=self._restore_coord_dims
1157+
)
1158+
return stacked
11321159

11331160
def map(
11341161
self,

xarray/tests/test_groupby.py

Lines changed: 27 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -556,27 +556,28 @@ def test_da_groupby_assign_coords() -> None:
556556
@pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")])
557557
def test_groupby_repr(obj, dim) -> None:
558558
actual = repr(obj.groupby(dim))
559-
expected = f"{obj.__class__.__name__}GroupBy"
560-
expected += f", grouped over {dim!r}"
561-
expected += f"\n{len(np.unique(obj[dim]))!r} groups with labels "
559+
N = len(np.unique(obj[dim]))
560+
expected = f"<{obj.__class__.__name__}GroupBy"
561+
expected += f", grouped over 1 grouper(s), {N} groups in total:"
562+
expected += f"\n\t{dim!r}: {N} groups with labels "
562563
if dim == "x":
563-
expected += "1, 2, 3, 4, 5."
564+
expected += "1, 2, 3, 4, 5>"
564565
elif dim == "y":
565-
expected += "0, 1, 2, 3, 4, 5, ..., 15, 16, 17, 18, 19."
566+
expected += "0, 1, 2, 3, 4, 5, ..., 15, 16, 17, 18, 19>"
566567
elif dim == "z":
567-
expected += "'a', 'b', 'c'."
568+
expected += "'a', 'b', 'c'>"
568569
elif dim == "month":
569-
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12."
570+
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>"
570571
assert actual == expected
571572

572573

573574
@pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")])
574575
def test_groupby_repr_datetime(obj) -> None:
575576
actual = repr(obj.groupby("t.month"))
576-
expected = f"{obj.__class__.__name__}GroupBy"
577-
expected += ", grouped over 'month'"
578-
expected += f"\n{len(np.unique(obj.t.dt.month))!r} groups with labels "
579-
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12."
577+
expected = f"<{obj.__class__.__name__}GroupBy"
578+
expected += ", grouped over 1 grouper(s), 12 groups in total:\n"
579+
expected += "\t'month': 12 groups with labels "
580+
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>"
580581
assert actual == expected
581582

582583

@@ -2561,3 +2562,18 @@ def factorize(self, group) -> EncodedGroups:
25612562
obj.groupby("time.year", time=YearGrouper())
25622563
with pytest.raises(ValueError):
25632564
obj.groupby()
2565+
2566+
2567+
def test_multiple_groupers() -> None:
2568+
da = xr.DataArray(
2569+
np.array([1, 2, 3, 0, 2, np.nan]),
2570+
dims="d",
2571+
coords=dict(
2572+
labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])),
2573+
labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])),
2574+
),
2575+
)
2576+
2577+
gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper())
2578+
repr(gb)
2579+
gb.mean()

0 commit comments

Comments
 (0)