Skip to content

Commit a46078d

Browse files
committed
Use SeriesGroupBy.unique to get blocks for each label.
1 parent 5c4c7a3 commit a46078d

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

flox/core.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -186,14 +186,11 @@ def find_group_cohorts(labels, chunks, merge: bool = True):
186186
blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
187187
which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)
188188

189-
# We always drop NaN; np.unique also considers every NaN to be different so
190-
# it's really important we get rid of them.
191189
raveled = labels.reshape(-1)
192-
unique_labels = np.unique(raveled[~isnull(raveled)])
193190
# these are chunks where a label is present
194-
label_chunks = {lab: tuple(np.unique(which_chunk[raveled == lab])) for lab in unique_labels}
191+
label_chunks = pd.Series(which_chunk).groupby(raveled).unique()
195192
# These invert the label_chunks mapping so we know which labels occur together.
196-
chunks_cohorts = tlz.groupby(label_chunks.get, label_chunks.keys())
193+
chunks_cohorts = tlz.groupby(lambda x: tuple(label_chunks.get(x)), label_chunks.keys())
197194

198195
if merge:
199196
# First sort by number of chunks occupied by cohort

0 commit comments

Comments
 (0)