Skip to content

Commit 09ab13c

Browse files
committed
Deprecate squeeze in GroupBy.
Closes pydata#2157
1 parent 5213f0d commit 09ab13c

File tree

4 files changed

+115
-59
lines changed

4 files changed

+115
-59
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ Breaking changes
4343

4444
Deprecations
4545
~~~~~~~~~~~~
46+
- The ``squeeze`` kwarg to GroupBy is now deprecated. (:issue:`2157`)
47+
By `Deepak Cherian <https://github.com/dcherian>`_.
4648

4749
- As part of an effort to standardize the API, we're renaming the ``dims``
4850
keyword arg to ``dim`` for the minority of functions which currently use

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6620,7 +6620,7 @@ def interp_calendar(
66206620
def groupby(
66216621
self,
66226622
group: Hashable | DataArray | IndexVariable,
6623-
squeeze: bool = True,
6623+
squeeze: bool | None = None,
66246624
restore_coord_dims: bool = False,
66256625
) -> DataArrayGroupBy:
66266626
"""Returns a DataArrayGroupBy object for performing grouped operations.

xarray/core/groupby.py

Lines changed: 76 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,21 @@ def check_reduce_dims(reduce_dims, dimensions):
7373
)
7474

7575

76+
def _maybe_squeeze_indices(
77+
indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool
78+
):
79+
if squeeze in [None, True] and grouper.can_squeeze:
80+
if squeeze is None and warn:
81+
emit_user_level_warning(
82+
"The `squeeze` kwarg to GroupBy is being removed. "
83+
"Pass .groupby(..., squeeze=False) to silence this warning."
84+
)
85+
if isinstance(indices, slice):
86+
assert indices.stop - indices.start == 1
87+
indices = indices.start
88+
return indices
89+
90+
7691
def unique_value_groups(
7792
ar, sort: bool = True
7893
) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]:
@@ -366,10 +381,10 @@ def dims(self):
366381
return self.group1d.dims
367382

368383
@abstractmethod
369-
def _factorize(self, squeeze: bool) -> T_FactorizeOut:
384+
def _factorize(self) -> T_FactorizeOut:
370385
raise NotImplementedError
371386

372-
def factorize(self, squeeze: bool) -> None:
387+
def factorize(self) -> None:
373388
# This design makes it clear to mypy that
374389
# codes, group_indices, unique_coord, and full_index
375390
# are set by the factorize method on the derived class.
@@ -378,7 +393,7 @@ def factorize(self, squeeze: bool) -> None:
378393
self.group_indices,
379394
self.unique_coord,
380395
self.full_index,
381-
) = self._factorize(squeeze)
396+
) = self._factorize()
382397

383398
@property
384399
def is_unique_and_monotonic(self) -> bool:
@@ -393,15 +408,19 @@ def group_as_index(self) -> pd.Index:
393408
self._group_as_index = self.group1d.to_index()
394409
return self._group_as_index
395410

411+
@property
412+
def can_squeeze(self):
413+
is_dimension = self.group.dims == (self.group.name,)
414+
return is_dimension and self.is_unique_and_monotonic
415+
396416

397417
@dataclass
398418
class ResolvedUniqueGrouper(ResolvedGrouper):
399419
grouper: UniqueGrouper
400420

401-
def _factorize(self, squeeze) -> T_FactorizeOut:
402-
is_dimension = self.group.dims == (self.group.name,)
403-
if is_dimension and self.is_unique_and_monotonic:
404-
return self._factorize_dummy(squeeze)
421+
def _factorize(self) -> T_FactorizeOut:
422+
if self.can_squeeze:
423+
return self._factorize_dummy()
405424
else:
406425
return self._factorize_unique()
407426

@@ -424,15 +443,12 @@ def _factorize_unique(self) -> T_FactorizeOut:
424443

425444
return codes, group_indices, unique_coord, full_index
426445

427-
def _factorize_dummy(self, squeeze) -> T_FactorizeOut:
446+
def _factorize_dummy(self) -> T_FactorizeOut:
428447
size = self.group.size
429448
# no need to factorize
430-
if not squeeze:
431-
# use slices to do views instead of fancy indexing
432-
# equivalent to: group_indices = group_indices.reshape(-1, 1)
433-
group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)]
434-
else:
435-
group_indices = list(range(size))
449+
# use slices to do views instead of fancy indexing
450+
# equivalent to: group_indices = group_indices.reshape(-1, 1)
451+
group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)]
436452
size_range = np.arange(size)
437453
if isinstance(self.group, _DummyGroup):
438454
codes = self.group.to_dataarray().copy(data=size_range)
@@ -448,7 +464,7 @@ def _factorize_dummy(self, squeeze) -> T_FactorizeOut:
448464
class ResolvedBinGrouper(ResolvedGrouper):
449465
grouper: BinGrouper
450466

451-
def _factorize(self, squeeze: bool) -> T_FactorizeOut:
467+
def _factorize(self) -> T_FactorizeOut:
452468
from xarray.core.dataarray import DataArray
453469

454470
data = self.group1d.values
@@ -546,7 +562,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
546562
_apply_loffset(self.grouper.loffset, first_items)
547563
return first_items, codes
548564

549-
def _factorize(self, squeeze: bool) -> T_FactorizeOut:
565+
def _factorize(self) -> T_FactorizeOut:
550566
full_index, first_items, codes_ = self._get_index_and_items()
551567
sbins = first_items.values.astype(np.int64)
552568
group_indices: T_GroupIndices = [
@@ -591,14 +607,14 @@ class TimeResampleGrouper(Grouper):
591607
loffset: datetime.timedelta | str | None
592608

593609

594-
def _validate_groupby_squeeze(squeeze: bool) -> None:
610+
def _validate_groupby_squeeze(squeeze: bool | None) -> None:
595611
# While we don't generally check the type of every arg, passing
596612
# multiple dimensions as multiple arguments is common enough, and the
597613
# consequences hidden enough (strings evaluate as true) to warrant
598614
# checking here.
599615
# A future version could make squeeze kwarg only, but would face
600616
# backward-compat issues.
601-
if not isinstance(squeeze, bool):
617+
if squeeze is not None and not isinstance(squeeze, bool):
602618
raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied")
603619

604620

@@ -730,7 +746,7 @@ def __init__(
730746
self._original_obj = obj
731747

732748
for grouper_ in self.groupers:
733-
grouper_.factorize(squeeze)
749+
grouper_.factorize()
734750

735751
(grouper,) = self.groupers
736752
self._original_group = grouper.group
@@ -762,9 +778,14 @@ def sizes(self) -> Mapping[Hashable, int]:
762778
Dataset.sizes
763779
"""
764780
if self._sizes is None:
765-
self._sizes = self._obj.isel(
766-
{self._group_dim: self._group_indices[0]}
767-
).sizes
781+
(grouper,) = self.groupers
782+
index = _maybe_squeeze_indices(
783+
self._group_indices[0],
784+
self._squeeze,
785+
grouper,
786+
warn=True,
787+
)
788+
self._sizes = self._obj.isel({self._group_dim: index}).sizes
768789

769790
return self._sizes
770791

@@ -798,14 +819,22 @@ def groups(self) -> dict[GroupKey, GroupIndex]:
798819
# provided to mimic pandas.groupby
799820
if self._groups is None:
800821
(grouper,) = self.groupers
801-
self._groups = dict(zip(grouper.unique_coord.values, self._group_indices))
822+
squeezed_indices = (
823+
_maybe_squeeze_indices(ind, self._squeeze, grouper, warn=idx == 0)
824+
for idx, ind in enumerate(self._group_indices)
825+
)
826+
self._groups = dict(zip(grouper.unique_coord.values, squeezed_indices))
802827
return self._groups
803828

804829
def __getitem__(self, key: GroupKey) -> T_Xarray:
805830
"""
806831
Get DataArray or Dataset corresponding to a particular group label.
807832
"""
808-
return self._obj.isel({self._group_dim: self.groups[key]})
833+
(grouper,) = self.groupers
834+
index = _maybe_squeeze_indices(
835+
self.groups[key], self._squeeze, grouper, warn=True
836+
)
837+
return self._obj.isel({self._group_dim: index})
809838

810839
def __len__(self) -> int:
811840
(grouper,) = self.groupers
@@ -826,7 +855,11 @@ def __repr__(self) -> str:
826855

827856
def _iter_grouped(self) -> Iterator[T_Xarray]:
828857
"""Iterate over each element in this group"""
829-
for indices in self._group_indices:
858+
(grouper,) = self.groupers
859+
for idx, indices in enumerate(self._group_indices):
860+
indices = _maybe_squeeze_indices(
861+
indices, self._squeeze, grouper, warn=idx == 0
862+
)
830863
yield self._obj.isel({self._group_dim: indices})
831864

832865
def _infer_concat_args(self, applied_example):
@@ -1309,7 +1342,11 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic):
13091342
@property
13101343
def dims(self) -> tuple[Hashable, ...]:
13111344
if self._dims is None:
1312-
self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims
1345+
(grouper,) = self.groupers
1346+
index = _maybe_squeeze_indices(
1347+
self._group_indices[0], self._squeeze, grouper, warn=True
1348+
)
1349+
self._dims = self._obj.isel({self._group_dim: index}).dims
13131350

13141351
return self._dims
13151352

@@ -1318,7 +1355,11 @@ def _iter_grouped_shortcut(self):
13181355
metadata
13191356
"""
13201357
var = self._obj.variable
1321-
for indices in self._group_indices:
1358+
(grouper,) = self.groupers
1359+
for idx, indices in enumerate(self._group_indices):
1360+
indices = _maybe_squeeze_indices(
1361+
indices, self._squeeze, grouper, warn=idx == 0
1362+
)
13221363
yield var[{self._group_dim: indices}]
13231364

13241365
def _concat_shortcut(self, applied, dim, positions=None):
@@ -1517,7 +1558,14 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic):
15171558
@property
15181559
def dims(self) -> Frozen[Hashable, int]:
15191560
if self._dims is None:
1520-
self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims
1561+
(grouper,) = self.groupers
1562+
index = _maybe_squeeze_indices(
1563+
self._group_indices[0],
1564+
self._squeeze,
1565+
grouper,
1566+
warn=True,
1567+
)
1568+
self._dims = self._obj.isel({self._group_dim: index}).dims
15211569

15221570
return self._dims
15231571

0 commit comments

Comments
 (0)