@@ -661,18 +661,26 @@ def __init__(
661
661
# specification for the groupby operation
662
662
# TODO: handle obj having variables that are not present on any of the groupers
663
663
# simple broadcasting fails for ExtensionArrays.
664
- # FIXME: Skip this stacking when grouping by a dask array, it's useless in that case.
665
- (self .group1d , self ._obj , self ._stacked_dim , self ._inserted_dims ) = _ensure_1d (
666
- group = self .encoded .codes , obj = obj
667
- )
668
- (self ._group_dim ,) = self .group1d .dims
664
+ codes = self .encoded .codes
665
+ self ._by_chunked = is_chunked_array (codes ._variable ._data )
666
+ if not self ._by_chunked :
667
+ (self .group1d , self ._obj , self ._stacked_dim , self ._inserted_dims ) = (
668
+ _ensure_1d (group = codes , obj = obj )
669
+ )
670
+ (self ._group_dim ,) = self .group1d .dims
671
+ else :
672
+ self .group1d = None
673
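+ # Move the grouper's dims to the end of obj's dim order; presumably this matches the dim order the stacking (non-chunked) path produces -- TODO confirm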
674
+ self ._obj = obj .transpose (..., * codes .dims )
675
+ self ._stacked_dim = None
676
+ self ._inserted_dims = []
677
+ self ._group_dim = None
669
678
670
679
# cached attributes
671
680
self ._groups = None
672
681
self ._dims = None
673
682
self ._sizes = None
674
683
self ._len = len (self .encoded .full_index )
675
- self ._by_chunked = is_chunked_array (self .encoded .codes .data )
676
684
677
685
@property
678
686
def sizes (self ) -> Mapping [Hashable , int ]:
@@ -817,6 +825,7 @@ def __getitem__(self, key: GroupKey) -> T_Xarray:
817
825
"""
818
826
Get DataArray or Dataset corresponding to a particular group label.
819
827
"""
828
+ self ._raise_if_by_is_chunked ()
820
829
return self ._obj .isel ({self ._group_dim : self .groups [key ]})
821
830
822
831
def __len__ (self ) -> int :
@@ -1331,9 +1340,6 @@ def quantile(
1331
1340
"Sample quantiles in statistical packages,"
1332
1341
The American Statistician, 50(4), pp. 361-365, 1996
1333
1342
"""
1334
- if dim is None :
1335
- dim = (self ._group_dim ,)
1336
-
1337
1343
# Dataset.quantile does this, do it for flox to ensure same output.
1338
1344
q = np .asarray (q , dtype = np .float64 )
1339
1345
@@ -1348,6 +1354,8 @@ def quantile(
1348
1354
)
1349
1355
return result
1350
1356
else :
1357
+ if dim is None :
1358
+ dim = (self ._group_dim ,)
1351
1359
return self .map (
1352
1360
self ._obj .__class__ .quantile ,
1353
1361
shortcut = False ,
@@ -1491,6 +1499,7 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic):
1491
1499
1492
1500
@property
1493
1501
def dims (self ) -> tuple [Hashable , ...]:
1502
+ self ._raise_if_by_is_chunked ()
1494
1503
if self ._dims is None :
1495
1504
index = self .encoded .group_indices [0 ]
1496
1505
self ._dims = self ._obj .isel ({self ._group_dim : index }).dims
@@ -1702,6 +1711,7 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic):
1702
1711
1703
1712
@property
1704
1713
def dims (self ) -> Frozen [Hashable , int ]:
1714
+ self ._raise_if_by_is_chunked ()
1705
1715
if self ._dims is None :
1706
1716
index = self .encoded .group_indices [0 ]
1707
1717
self ._dims = self ._obj .isel ({self ._group_dim : index }).dims
0 commit comments