@@ -73,6 +73,21 @@ def check_reduce_dims(reduce_dims, dimensions):
73
73
)
74
74
75
75
76
def _maybe_squeeze_indices(
    indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool
):
    """Collapse a one-element slice indexer to its integer start when squeezing.

    Parameters
    ----------
    indices
        Indexer for a single group. Only ``slice`` objects are rewritten;
        any other value is returned unchanged.
    squeeze : bool or None
        ``True`` or ``None`` enables squeezing, provided ``grouper.can_squeeze``
        is truthy. ``False`` disables it.
    grouper : ResolvedGrouper
        Only its ``can_squeeze`` attribute is consulted.
    warn : bool
        When true, and ``squeeze`` is ``None``, and squeezing applies, the
        deprecation warning is emitted via ``emit_user_level_warning``.

    Returns
    -------
    ``indices.start`` if ``indices`` was a slice and squeezing applies,
    otherwise ``indices`` unchanged.

    Raises
    ------
    AssertionError
        If squeezing applies to a slice whose ``stop - start`` is not 1.
    """
    if squeeze in (None, True) and grouper.can_squeeze:
        if squeeze is None and warn:
            # Adjacent literals are concatenated verbatim, so the separating
            # space between the two sentences must be written explicitly.
            emit_user_level_warning(
                "The `squeeze` kwarg to GroupBy is being removed. "
                "Pass .groupby(..., squeeze=False) to silence this warning."
            )
        if isinstance(indices, slice):
            # Internal invariant: a squeezable slice selects exactly one element.
            assert indices.stop - indices.start == 1
            indices = indices.start
    return indices
89
+
90
+
76
91
def unique_value_groups (
77
92
ar , sort : bool = True
78
93
) -> tuple [np .ndarray | pd .Index , T_GroupIndices , np .ndarray ]:
@@ -366,10 +381,10 @@ def dims(self):
366
381
return self .group1d .dims
367
382
368
383
@abstractmethod
def _factorize(self) -> T_FactorizeOut:
    """Hook that derived groupers implement to do the actual factorization.

    The public ``factorize`` method unpacks the returned tuple into
    ``codes``, ``group_indices``, ``unique_coord`` and ``full_index``.
    """
    raise NotImplementedError()
371
386
372
- def factorize (self , squeeze : bool ) -> None :
387
+ def factorize (self ) -> None :
373
388
# This design makes it clear to mypy that
374
389
# codes, group_indices, unique_coord, and full_index
375
390
# are set by the factorize method on the derived class.
@@ -378,7 +393,7 @@ def factorize(self, squeeze: bool) -> None:
378
393
self .group_indices ,
379
394
self .unique_coord ,
380
395
self .full_index ,
381
- ) = self ._factorize (squeeze )
396
+ ) = self ._factorize ()
382
397
383
398
@property
384
399
def is_unique_and_monotonic (self ) -> bool :
@@ -393,15 +408,19 @@ def group_as_index(self) -> pd.Index:
393
408
self ._group_as_index = self .group1d .to_index ()
394
409
return self ._group_as_index
395
410
411
@property
def can_squeeze(self):
    """Whether the group is its own sole dimension and is unique and monotonic.

    Returns ``False`` when ``self.group.dims`` is not exactly
    ``(self.group.name,)``; otherwise returns ``self.is_unique_and_monotonic``.
    """
    group = self.group
    if group.dims != (group.name,):
        return False
    return self.is_unique_and_monotonic
415
+
396
416
397
417
@dataclass
398
418
class ResolvedUniqueGrouper (ResolvedGrouper ):
399
419
grouper : UniqueGrouper
400
420
401
def _factorize(self) -> T_FactorizeOut:
    """Factorize the group, using the cheap path when squeezing is possible.

    This implements the abstract ``_factorize`` hook of the base class, whose
    ``factorize()`` method calls ``self._factorize()`` and stores the result.
    It must therefore keep the underscore-prefixed name; defining it as
    ``factorize`` would leave the hook unimplemented.

    Returns
    -------
    The result of ``_factorize_dummy()`` when ``self.can_squeeze`` is truthy,
    otherwise the result of ``_factorize_unique()``.
    """
    if self.can_squeeze:
        return self._factorize_dummy()
    return self._factorize_unique()
407
426
@@ -424,15 +443,12 @@ def _factorize_unique(self) -> T_FactorizeOut:
424
443
425
444
return codes , group_indices , unique_coord , full_index
426
445
427
- def _factorize_dummy (self , squeeze ) -> T_FactorizeOut :
446
+ def _factorize_dummy (self ) -> T_FactorizeOut :
428
447
size = self .group .size
429
448
# no need to factorize
430
- if not squeeze :
431
- # use slices to do views instead of fancy indexing
432
- # equivalent to: group_indices = group_indices.reshape(-1, 1)
433
- group_indices : T_GroupIndices = [slice (i , i + 1 ) for i in range (size )]
434
- else :
435
- group_indices = list (range (size ))
449
+ # use slices to do views instead of fancy indexing
450
+ # equivalent to: group_indices = group_indices.reshape(-1, 1)
451
+ group_indices : T_GroupIndices = [slice (i , i + 1 ) for i in range (size )]
436
452
size_range = np .arange (size )
437
453
if isinstance (self .group , _DummyGroup ):
438
454
codes = self .group .to_dataarray ().copy (data = size_range )
@@ -448,7 +464,7 @@ def _factorize_dummy(self, squeeze) -> T_FactorizeOut:
448
464
class ResolvedBinGrouper (ResolvedGrouper ):
449
465
grouper : BinGrouper
450
466
451
- def _factorize (self , squeeze : bool ) -> T_FactorizeOut :
467
+ def factorize (self ) -> T_FactorizeOut :
452
468
from xarray .core .dataarray import DataArray
453
469
454
470
data = self .group1d .values
@@ -546,7 +562,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
546
562
_apply_loffset (self .grouper .loffset , first_items )
547
563
return first_items , codes
548
564
549
- def _factorize (self , squeeze : bool ) -> T_FactorizeOut :
565
+ def factorize (self ) -> T_FactorizeOut :
550
566
full_index , first_items , codes_ = self ._get_index_and_items ()
551
567
sbins = first_items .values .astype (np .int64 )
552
568
group_indices : T_GroupIndices = [
@@ -591,14 +607,14 @@ class TimeResampleGrouper(Grouper):
591
607
loffset : datetime .timedelta | str | None
592
608
593
609
594
def _validate_groupby_squeeze(squeeze: bool | None) -> None:
    """Raise ``TypeError`` unless ``squeeze`` is ``None`` or a ``bool``.

    Raises
    ------
    TypeError
        If ``squeeze`` is neither ``None`` nor an instance of ``bool``.
    """
    # Argument types are not generally checked, but callers commonly pass
    # several dimensions as separate positional arguments, and a stray string
    # landing in ``squeeze`` would silently evaluate as true. Making
    # ``squeeze`` keyword-only in a future version would avoid this, at the
    # cost of backward compatibility.
    if squeeze is None or isinstance(squeeze, bool):
        return
    raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied")
603
619
604
620
@@ -730,7 +746,7 @@ def __init__(
730
746
self ._original_obj = obj
731
747
732
748
for grouper_ in self .groupers :
733
- grouper_ .factorize ( squeeze )
749
+ grouper_ ._factorize ( )
734
750
735
751
(grouper ,) = self .groupers
736
752
self ._original_group = grouper .group
@@ -762,9 +778,14 @@ def sizes(self) -> Mapping[Hashable, int]:
762
778
Dataset.sizes
763
779
"""
764
780
if self ._sizes is None :
765
- self ._sizes = self ._obj .isel (
766
- {self ._group_dim : self ._group_indices [0 ]}
767
- ).sizes
781
+ (grouper ,) = self .groupers
782
+ index = _maybe_squeeze_indices (
783
+ self ._group_indices [0 ],
784
+ self ._squeeze ,
785
+ grouper ,
786
+ warn = True ,
787
+ )
788
+ self ._sizes = self ._obj .isel ({self ._group_dim : index }).sizes
768
789
769
790
return self ._sizes
770
791
@@ -798,14 +819,22 @@ def groups(self) -> dict[GroupKey, GroupIndex]:
798
819
# provided to mimic pandas.groupby
799
820
if self ._groups is None :
800
821
(grouper ,) = self .groupers
801
- self ._groups = dict (zip (grouper .unique_coord .values , self ._group_indices ))
822
+ squeezed_indices = (
823
+ _maybe_squeeze_indices (ind , self ._squeeze , grouper , warn = idx > 0 )
824
+ for idx , ind in enumerate (self ._group_indices )
825
+ )
826
+ self ._groups = dict (zip (grouper .unique_coord .values , squeezed_indices ))
802
827
return self ._groups
803
828
804
829
def __getitem__(self, key: GroupKey) -> T_Xarray:
    """
    Get DataArray or Dataset corresponding to a particular group label.

    The indexer stored in ``self.groups`` for ``key`` is passed through
    ``_maybe_squeeze_indices`` (with warnings enabled) before selecting
    along the grouped dimension.
    """
    (resolved,) = self.groupers
    positions = _maybe_squeeze_indices(
        self.groups[key], self._squeeze, resolved, warn=True
    )
    indexer = {self._group_dim: positions}
    return self._obj.isel(indexer)
809
838
810
839
def __len__ (self ) -> int :
811
840
(grouper ,) = self .groupers
@@ -826,7 +855,11 @@ def __repr__(self) -> str:
826
855
827
856
def _iter_grouped(self) -> Iterator[T_Xarray]:
    """Iterate over each element in this group.

    Every group's indexer is passed through ``_maybe_squeeze_indices`` and
    then used to select from the wrapped object along the grouped dimension.
    """
    (resolved,) = self.groupers
    for position, group_indexer in enumerate(self._group_indices):
        # The warning flag is off for the first group and on for the rest.
        warn_here = position > 0
        selector = _maybe_squeeze_indices(
            group_indexer, self._squeeze, resolved, warn=warn_here
        )
        yield self._obj.isel({self._group_dim: selector})
831
864
832
865
def _infer_concat_args (self , applied_example ):
@@ -1309,7 +1342,11 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic):
1309
1342
@property
def dims(self) -> tuple[Hashable, ...]:
    """Dimensions of the first group's selection, computed once and cached.

    On first access the first group indexer is passed through
    ``_maybe_squeeze_indices`` (warnings enabled), applied along the grouped
    dimension, and the resulting ``dims`` stored in ``self._dims``.
    """
    if self._dims is not None:
        return self._dims

    (resolved,) = self.groupers
    first = _maybe_squeeze_indices(
        self._group_indices[0], self._squeeze, resolved, warn=True
    )
    selection = self._obj.isel({self._group_dim: first})
    self._dims = selection.dims
    return self._dims
1315
1352
@@ -1318,7 +1355,11 @@ def _iter_grouped_shortcut(self):
1318
1355
metadata
1319
1356
"""
1320
1357
var = self ._obj .variable
1321
- for indices in self ._group_indices :
1358
+ (grouper ,) = self .groupers
1359
+ for idx , indices in enumerate (self ._group_indices ):
1360
+ indices = _maybe_squeeze_indices (
1361
+ indices , self ._squeeze , grouper , warn = idx > 0
1362
+ )
1322
1363
yield var [{self ._group_dim : indices }]
1323
1364
1324
1365
def _concat_shortcut (self , applied , dim , positions = None ):
@@ -1517,7 +1558,14 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic):
1517
1558
@property
def dims(self) -> Frozen[Hashable, int]:
    """Dimensions of the first group's selection, computed lazily and cached.

    On first access the first group indexer is passed through
    ``_maybe_squeeze_indices`` (warnings enabled), applied along the grouped
    dimension, and the resulting ``dims`` stored in ``self._dims``.
    """
    cached = self._dims
    if cached is None:
        (resolved,) = self.groupers
        first = _maybe_squeeze_indices(
            self._group_indices[0], self._squeeze, resolved, warn=True
        )
        indexer = {self._group_dim: first}
        self._dims = self._obj.isel(indexer).dims
    return self._dims
1523
1571
0 commit comments