1
1
from __future__ import annotations
2
2
3
3
import copy
4
+ import functools
5
+ import math
4
6
import warnings
5
7
from collections .abc import Callable , Hashable , Iterator , Mapping , Sequence
6
8
from dataclasses import dataclass , field
@@ -68,10 +70,11 @@ def check_reduce_dims(reduce_dims, dimensions):
68
70
)
69
71
70
72
71
- def _codes_to_group_indices (inverse : np .ndarray , N : int ) -> GroupIndices :
72
- assert inverse .ndim == 1
73
+ def _codes_to_group_indices (codes : np .ndarray , N : int ) -> GroupIndices :
74
+ """Converts integer codes for groups to group indices."""
75
+ assert codes .ndim == 1
73
76
groups : GroupIndices = tuple ([] for _ in range (N ))
74
- for n , g in enumerate (inverse ):
77
+ for n , g in enumerate (codes ):
75
78
if g >= 0 :
76
79
groups [g ].append (n )
77
80
return groups
@@ -448,7 +451,7 @@ class GroupBy(Generic[T_Xarray]):
448
451
"_codes" ,
449
452
)
450
453
_obj : T_Xarray
451
- groupers : tuple [ResolvedGrouper ]
454
+ groupers : tuple [ResolvedGrouper , ... ]
452
455
_restore_coord_dims : bool
453
456
454
457
_original_obj : T_Xarray
@@ -464,7 +467,7 @@ class GroupBy(Generic[T_Xarray]):
464
467
def __init__ (
465
468
self ,
466
469
obj : T_Xarray ,
467
- groupers : tuple [ResolvedGrouper ],
470
+ groupers : tuple [ResolvedGrouper , ... ],
468
471
restore_coord_dims : bool = True ,
469
472
) -> None :
470
473
"""Create a GroupBy object
@@ -483,16 +486,35 @@ def __init__(
483
486
484
487
self ._original_obj = obj
485
488
486
- (grouper ,) = self .groupers
487
- self ._original_group = grouper .group
489
+ if len (groupers ) > 1 :
490
+ for grouper in groupers :
491
+ if grouper .group .ndim > 1 :
492
+ raise NotImplementedError (
493
+ "Only grouping by multiple 1D variables is supported at the moment."
494
+ )
495
+ (grouper , * _ ) = self .groupers # FIXME
496
+ self ._original_group = grouper .group # FIXME
488
497
489
498
# specification for the groupby operation
490
- self ._obj = grouper .stacked_obj
499
+ self ._obj = grouper .stacked_obj # FIXME
491
500
self ._restore_coord_dims = restore_coord_dims
492
501
493
- # These should generalize to multiple groupers
494
- self ._group_indices = grouper .group_indices
495
- self ._codes = self ._maybe_unstack (grouper .codes )
502
+ self ._shape = tuple (grouper .size for grouper in groupers )
503
+ self ._len = math .prod (self ._shape )
504
+
505
+ self ._codes = tuple (self ._maybe_unstack (grouper .codes ) for grouper in groupers )
506
+ self ._flatcodes = np .ravel_multi_index (self ._codes , self ._shape , mode = "wrap" )
507
+ # NaNs; as well as values outside the bins are coded by -1
508
+ # Restore these after the raveling
509
+ mask = functools .reduce (np .logical_or , [(code == - 1 ) for code in self ._codes ])
510
+ self ._flatcodes [mask ] = - 1
511
+
512
+ if len (groupers ) == 1 :
513
+ # For ordered `group` we index into the array using slices.
514
+ # Preserve this optimization when grouping by a single variable
515
+ self ._group_indices = self .groupers [0 ].group_indices
516
+ else :
517
+ self ._group_indices = _codes_to_group_indices (self ._flatcodes , self ._len )
496
518
497
519
(self ._group_dim ,) = grouper .group1d .dims
498
520
# cached attributes
@@ -566,13 +588,16 @@ def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]:
566
588
return zip (grouper .unique_coord .data , self ._iter_grouped ())
567
589
568
590
def __repr__ (self ) -> str :
569
- (grouper ,) = self .groupers
570
- return "{}, grouped over {!r}\n {!r} groups with labels {}." .format (
571
- self .__class__ .__name__ ,
572
- grouper .name ,
573
- grouper .full_index .size ,
574
- ", " .join (format_array_flat (grouper .full_index , 30 ).split ()),
591
+ text = (
592
+ f"<{ self .__class__ .__name__ } , "
593
+ f"grouped over { len (self .groupers )} grouper(s),"
594
+ f" { self ._len } groups in total:"
575
595
)
596
+ for grouper in self .groupers :
597
+ coord = grouper .unique_coord
598
+ labels = ", " .join (format_array_flat (coord , 30 ).split ())
599
+ text += f"\n \t { grouper .name !r} : { coord .size } groups with labels { labels } "
600
+ return text + ">"
576
601
577
602
def _iter_grouped (self ) -> Iterator [T_Xarray ]:
578
603
"""Iterate over each element in this group"""
@@ -609,7 +634,7 @@ def _binary_op(self, other, f, reflexive=False):
609
634
obj = self ._original_obj
610
635
name = grouper .name
611
636
group = grouper .group
612
- codes = self ._codes
637
+ ( codes ,) = self ._codes
613
638
dims = group .dims
614
639
615
640
if isinstance (group , _DummyGroup ):
@@ -709,15 +734,16 @@ def _maybe_restore_empty_groups(self, combined):
709
734
def _maybe_unstack (self , obj ):
710
735
"""This gets called if we are applying on an array with a
711
736
multidimensional group."""
712
- (grouper ,) = self .groupers
713
- stacked_dim = grouper .stacked_dim
714
- inserted_dims = grouper .inserted_dims
715
- if stacked_dim is not None and stacked_dim in obj .dims :
716
- obj = obj .unstack (stacked_dim )
717
- for dim in inserted_dims :
718
- if dim in obj .coords :
719
- del obj .coords [dim ]
720
- obj ._indexes = filter_indexes_from_coords (obj ._indexes , set (obj .coords ))
737
+ # TODO: Is this really right?
738
+ for grouper in self .groupers :
739
+ stacked_dim = grouper .stacked_dim
740
+ if stacked_dim is not None and stacked_dim in obj .dims :
741
+ inserted_dims = grouper .inserted_dims
742
+ obj = obj .unstack (stacked_dim )
743
+ for dim in inserted_dims :
744
+ if dim in obj .coords :
745
+ del obj .coords [dim ]
746
+ obj ._indexes = filter_indexes_from_coords (obj ._indexes , set (obj .coords ))
721
747
return obj
722
748
723
749
def _flox_reduce (
@@ -1115,20 +1141,21 @@ def _concat_shortcut(self, applied, dim, positions=None):
1115
1141
return self ._obj ._replace_maybe_drop_dims (reordered )
1116
1142
1117
1143
def _restore_dim_order (self , stacked : DataArray ) -> DataArray :
1118
- (grouper ,) = self .groupers
1119
- group = grouper .group1d
1120
-
1121
1144
def lookup_order (dimension ):
1122
- if dimension == grouper .name :
1123
- (dimension ,) = group .dims
1145
+ for grouper in self .groupers :
1146
+ if dimension == grouper .name and grouper .group .ndim == 1 :
1147
+ (dimension ,) = grouper .group .dims
1124
1148
if dimension in self ._obj .dims :
1125
1149
axis = self ._obj .get_axis_num (dimension )
1126
1150
else :
1127
1151
axis = 1e6 # some arbitrarily high value
1128
1152
return axis
1129
1153
1130
1154
new_order = sorted (stacked .dims , key = lookup_order )
1131
- return stacked .transpose (* new_order , transpose_coords = self ._restore_coord_dims )
1155
+ stacked = stacked .transpose (
1156
+ * new_order , transpose_coords = self ._restore_coord_dims
1157
+ )
1158
+ return stacked
1132
1159
1133
1160
def map (
1134
1161
self ,
0 commit comments