@@ -106,7 +106,7 @@ def _collapse_axis(arr: np.ndarray, naxis: int) -> np.ndarray:
 def _get_optimal_chunks_for_groups(chunks, labels):
     chunkidx = np.cumsum(chunks) - 1
     # what are the groups at chunk boundaries
-    labels_at_chunk_bounds = np.unique(labels[chunkidx])
+    labels_at_chunk_bounds = _unique(labels[chunkidx])
     # what's the last index of all groups
     last_indexes = npg.aggregate_numpy.aggregate(labels, np.arange(len(labels)), func="last")
     # what's the last index of groups at the chunk boundaries.
@@ -136,6 +136,12 @@ def _get_optimal_chunks_for_groups(chunks, labels):
     return tuple(newchunks)
 
 
+def _unique(a):
+    """Much faster to use pandas unique and sort the results.
+    np.unique sorts before uniquifying and is slow."""
+    return np.sort(pd.unique(a))
+
+
 @memoize
 def find_group_cohorts(labels, chunks, merge: bool = True):
     """
@@ -180,14 +186,11 @@ def find_group_cohorts(labels, chunks, merge: bool = True):
         blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
     which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)
 
-    # We always drop NaN; np.unique also considers every NaN to be different so
-    # it's really important we get rid of them.
     raveled = labels.reshape(-1)
-    unique_labels = np.unique(raveled[~isnull(raveled)])
     # these are chunks where a label is present
-    label_chunks = {lab: tuple(np.unique(which_chunk[raveled == lab])) for lab in unique_labels}
+    label_chunks = pd.Series(which_chunk).groupby(raveled).unique()
     # These invert the label_chunks mapping so we know which labels occur together.
-    chunks_cohorts = tlz.groupby(label_chunks.get, label_chunks.keys())
+    chunks_cohorts = tlz.groupby(lambda x: tuple(label_chunks.get(x)), label_chunks.keys())
 
     if merge:
         # First sort by number of chunks occupied by cohort
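
For context, the `groupby(...).unique()` call builds the label-to-chunks mapping in one vectorized pass instead of one boolean mask per label, and `tlz.groupby` then inverts it so labels occupying the same set of chunks land in one cohort. A small sketch with made-up labels and chunk ids (all values hypothetical):

```python
import numpy as np
import pandas as pd
import toolz as tlz

which_chunk = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])               # chunk id per element
raveled = np.array(["a", "a", "b", "a", "b", "c", "c", "c", "c"])  # label per element

# label -> array of chunks containing it, computed in one pass
label_chunks = pd.Series(which_chunk).groupby(raveled).unique()

# invert: labels that live in the same chunks form a cohort
chunks_cohorts = tlz.groupby(lambda x: tuple(label_chunks.get(x)), label_chunks.keys())
print(chunks_cohorts)  # roughly {(0, 1): ['a', 'b'], (1, 2): ['c']}
```
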
@@ -892,7 +895,7 @@ def _grouped_combine(
         # when there's only a single axis of reduction, we can just concatenate later,
         # reindexing is unnecessary
         # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
-        unique_groups = np.unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
+        unique_groups = _unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
         unique_groups = unique_groups[~isnull(unique_groups)]
         if len(unique_groups) == 0:
             unique_groups = [np.nan]
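
A note on the unchanged `isnull` filter just below the edited line: `pd.unique` collapses repeated NaNs into a single entry and `np.sort` pushes it to the end, so dropping nulls afterwards behaves as before. A quick illustrative check, using `pd.isnull` as a stand-in for the `isnull` helper the module imports:

```python
import numpy as np
import pandas as pd

groups = np.array([1.0, np.nan, 2.0, np.nan, 1.0])

unique_groups = np.sort(pd.unique(groups))                # array([1., 2., nan])
unique_groups = unique_groups[~pd.isnull(unique_groups)]  # drop the single NaN
print(unique_groups)                                      # [1. 2.]
```
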
@@ -1065,7 +1068,7 @@ def subset_to_blocks(
     unraveled = np.unravel_index(flatblocks, blkshape)
     normalized: list[Union[int, np.ndarray, slice]] = []
     for ax, idx in enumerate(unraveled):
-        i = np.unique(idx).squeeze()
+        i = _unique(idx).squeeze()
         if i.ndim == 0:
             normalized.append(i.item())
         else:
@@ -1310,7 +1313,7 @@ def dask_groupby_agg(
         # along the reduced axis
         slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
         if expected_groups is None:
-            groups_in_block = tuple(np.unique(by_input[slc]) for slc in slices)
+            groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
         else:
             # For cohorts, we could be indexing a block with groups that
             # are not in the cohort (usually for nD `by`)
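
Finally, the per-block group detection above can be pictured with a tiny 1-D example; `slices_from_chunks` comes from `dask.array.core`, and the chunking and labels below are hypothetical:

```python
import numpy as np
import pandas as pd
from dask.array.core import slices_from_chunks

by_input = np.array([0, 0, 1, 1, 1, 2, 2, 3])  # group label per element
chunks = ((3, 3, 2),)                          # hypothetical chunking along the reduced axis

slices = slices_from_chunks(chunks)            # one slice tuple per block
groups_in_block = tuple(np.sort(pd.unique(by_input[slc])) for slc in slices)
print(groups_in_block)  # (array([0, 1]), array([1, 2]), array([2, 3]))
```
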