Skip to content

Commit 046cf68

Browse files
authored
Many optimizations (#120)
* Skip factorizing with RangeIndex fastpath for binning by multiple variables. * Workaround pandas-dev/pandas#47614 * Avoid dispatching to pandas searchsorted * Remove unused variable. * ravel to reshape(-1) * Revert "Avoid dispatching to pandas searchsorted" This reverts commit 9aab6a4c194fa5c14b1f28ccb89dc7f8f8ebaa7d.
1 parent 61b134b commit 046cf68

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

flox/core.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _get_expected_groups(by, sort, *, raise_if_dask=True) -> pd.Index | None:
6464
if raise_if_dask:
6565
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
6666
return None
67-
flatby = by.ravel()
67+
flatby = by.reshape(-1)
6868
expected = pd.unique(flatby[~isnull(flatby)])
6969
return _convert_expected_groups_to_index((expected,), isbin=(False,), sort=sort)[0]
7070

@@ -175,11 +175,11 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
175175
blocks = np.empty(np.prod(shape), dtype=object)
176176
for idx, block in enumerate(array.blocks.ravel()):
177177
blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
178-
which_chunk = np.block(blocks.reshape(shape).tolist()).ravel()
178+
which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)
179179

180180
# We always drop NaN; np.unique also considers every NaN to be different so
181181
# it's really important we get rid of them.
182-
raveled = labels.ravel()
182+
raveled = labels.reshape(-1)
183183
unique_labels = np.unique(raveled[~isnull(raveled)])
184184
# these are chunks where a label is present
185185
label_chunks = {lab: tuple(np.unique(which_chunk[raveled == lab])) for lab in unique_labels}
@@ -421,19 +421,28 @@ def factorize_(
421421
factorized = []
422422
found_groups = []
423423
for groupvar, expect in zip(by, expected_groups):
424-
flat = groupvar.ravel()
425-
if isinstance(expect, pd.IntervalIndex):
424+
flat = groupvar.reshape(-1)
425+
if isinstance(expect, pd.RangeIndex):
426+
idx = flat
427+
found_groups.append(np.array(expect))
428+
# TODO: fix by using masked integers
429+
idx[idx > expect[-1]] = -1
430+
431+
elif isinstance(expect, pd.IntervalIndex):
426432
# when binning we change expected groups to integers marking the interval
427433
# this makes the reindexing logic simpler.
428-
if expect is None:
429-
raise ValueError("Please pass bin edges in expected_groups.")
430-
# TODO: fix for binning
431-
found_groups.append(expect)
432-
# pd.cut with bins = IntervalIndex[datetime64] doesn't work...
434+
# workaround for https://github.com/pandas-dev/pandas/issues/47614
435+
# we create breaks and pass that to pd.cut, disallow closed="both" for now.
436+
if expect.closed == "both":
437+
raise NotImplementedError
433438
if groupvar.dtype.kind == "M":
434-
expect = np.concatenate([expect.left.to_numpy(), [expect.right[-1].to_numpy()]])
439+
# pd.cut with bins = IntervalIndex[datetime64] doesn't work...
440+
bins = np.concatenate([expect.left.to_numpy(), [expect.right[-1].to_numpy()]])
441+
else:
442+
bins = np.concatenate([expect.left.to_numpy(), [expect.right[-1]]])
435443
# code is -1 for values outside the bounds of all intervals
436-
idx = pd.cut(flat, bins=expect).codes.copy()
444+
idx = pd.cut(flat, bins=bins, right=expect.closed_right).codes.copy()
445+
found_groups.append(expect)
437446
else:
438447
if expect is not None and reindex:
439448
sorter = np.argsort(expect)
@@ -447,7 +456,7 @@ def factorize_(
447456
idx = sorter[(idx,)]
448457
idx[mask] = -1
449458
else:
450-
idx, groups = pd.factorize(groupvar.ravel(), sort=sort)
459+
idx, groups = pd.factorize(flat, sort=sort)
451460

452461
found_groups.append(np.array(groups))
453462
factorized.append(idx)
@@ -473,7 +482,7 @@ def factorize_(
473482
# we collapse to a 2D by and axis=-1
474483
offset_group = True
475484
group_idx, size = offset_labels(group_idx.reshape(by[0].shape), ngroups)
476-
group_idx = group_idx.ravel()
485+
group_idx = group_idx.reshape(-1)
477486
else:
478487
size = ngroups
479488
offset_group = False
@@ -622,7 +631,7 @@ def chunk_reduce(
622631
# avoid by factorizing again so indices=[2,2,2] is changed to
623632
# indices=[0,0,0]. This is necessary when combining block results
624633
# factorize can handle strings etc unlike digitize
625-
group_idx, groups, found_groups_shape, ngroups, size, props = factorize_(
634+
group_idx, groups, found_groups_shape, _, size, props = factorize_(
626635
(by,), axis, expected_groups=(expected_groups,), reindex=reindex, sort=sort
627636
)
628637
groups = groups[0]

0 commit comments

Comments
 (0)