Commit 0bf35e0

Performance improvements for cohorts detection (#172)
* Faster unique
* SeriesGroupBy.unique to get blocks for each label.
* Add cohorts benchmark
* Add empty cache if cachey is absent
* Always initialize cache
* Guard cache clear
* Add numtasks benchmark
* Update asv_bench/benchmarks/cohorts.py
Parent: 878e284

4 files changed: +74 −9 lines

asv_bench/benchmarks/cohorts.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
import dask
import numpy as np
import pandas as pd

import flox


class Cohorts:
    """Time the core reduction function."""

    def setup(self, *args, **kwargs):
        raise NotImplementedError

    def time_find_group_cohorts(self):
        flox.core.find_group_cohorts(self.by, self.array.chunks)
        # The cache clear fails dependably in CI
        # Not sure why
        try:
            flox.cache.cache.clear()
        except AttributeError:
            pass

    def time_graph_construct(self):
        flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis, method="cohorts")

    def track_num_tasks(self):
        result = flox.groupby_reduce(
            self.array, self.by, func="sum", axis=self.axis, method="cohorts"
        )[0]
        return len(result.dask.to_dict())

    track_num_tasks.unit = "tasks"


class NWMMidwest(Cohorts):
    """2D labels, irregular w.r.t. chunk size.
    Mimics National Weather Model, Midwest county groupby."""

    def setup(self, *args, **kwargs):
        x = np.repeat(np.arange(30), 150)
        y = np.repeat(np.arange(30), 60)
        self.by = x[np.newaxis, :] * y[:, np.newaxis]

        self.array = dask.array.ones(self.by.shape, chunks=(350, 350))
        self.axis = (-2, -1)


class ERA5(Cohorts):
    """ERA5"""

    def setup(self, *args, **kwargs):
        time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))

        self.by = time.dt.dayofyear.values
        self.axis = (-1,)

        array = dask.array.random.random((721, 1440, len(time)), chunks=(-1, -1, 48))
        self.array = flox.core.rechunk_for_cohorts(
            array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
        )

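For a quick sanity check outside a full asv run, the classes above can be driven by hand the same way asv drives them: call setup() and then each timed or tracked method. The snippet below is only an illustrative sketch and assumes the benchmark classes are importable or already in scope (asv normally takes care of that).

# Hypothetical manual run of one benchmark class from the file above.
bench = NWMMidwest()
bench.setup()
bench.time_find_group_cohorts()   # exercises flox.core.find_group_cohorts
bench.time_graph_construct()      # builds the "cohorts" graph via flox.groupby_reduce
print(bench.track_num_tasks(), "tasks in the graph")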
flox/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # flake8: noqa
 """Top-level module for flox ."""
+from . import cache
 from .aggregations import Aggregation  # noqa
 from .core import groupby_reduce, rechunk_for_blockwise, rechunk_for_cohorts  # noqa

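The single added import is what lets the benchmark's cache-clearing guard spell flox.cache.cache.clear() without importing the submodule separately: from . import cache binds the cache module as an attribute of the flox package at import time. A minimal illustration, mirroring the guard used in the benchmark above:

import flox

# With `from . import cache` in flox/__init__.py, the submodule is reachable as an
# attribute of the package; the AttributeError guard stays as a safety net.
try:
    flox.cache.cache.clear()
except AttributeError:
    pass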
flox/cache.py

Lines changed: 1 addition & 0 deletions
@@ -8,4 +8,5 @@
     cache = cachey.Cache(1e6)
     memoize = partial(cache.memoize, key=dask.base.tokenize)
 except ImportError:
+    cache = {}
     memoize = lambda x: x  # type: ignore
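With the one-line fallback, flox.cache.cache always exists whether or not cachey is installed, and in the no-cachey case the plain dict supports .clear() directly, so callers never need to special-case a missing cache. Roughly, the module now reads like the sketch below; the try/except ImportError framing is implied by the hunk rather than shown, so treat this as an approximation of the file, not the file itself.

from functools import partial

try:
    import cachey
    import dask

    # Small cachey cache keyed by dask's tokenize, when cachey is available.
    cache = cachey.Cache(1e6)
    memoize = partial(cache.memoize, key=dask.base.tokenize)
except ImportError:
    cache = {}  # a plain dict still provides .clear()
    memoize = lambda x: x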

flox/core.py

Lines changed: 12 additions & 9 deletions
@@ -106,7 +106,7 @@ def _collapse_axis(arr: np.ndarray, naxis: int) -> np.ndarray:
 def _get_optimal_chunks_for_groups(chunks, labels):
     chunkidx = np.cumsum(chunks) - 1
     # what are the groups at chunk boundaries
-    labels_at_chunk_bounds = np.unique(labels[chunkidx])
+    labels_at_chunk_bounds = _unique(labels[chunkidx])
     # what's the last index of all groups
     last_indexes = npg.aggregate_numpy.aggregate(labels, np.arange(len(labels)), func="last")
     # what's the last index of groups at the chunk boundaries.
@@ -136,6 +136,12 @@ def _get_optimal_chunks_for_groups(chunks, labels):
     return tuple(newchunks)


+def _unique(a):
+    """Much faster to use pandas unique and sort the results.
+    np.unique sorts before uniquifying and is slow."""
+    return np.sort(pd.unique(a))
+
+
 @memoize
 def find_group_cohorts(labels, chunks, merge: bool = True):
     """
@@ -180,14 +186,11 @@ def find_group_cohorts(labels, chunks, merge: bool = True):
         blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
     which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)

-    # We always drop NaN; np.unique also considers every NaN to be different so
-    # it's really important we get rid of them.
     raveled = labels.reshape(-1)
-    unique_labels = np.unique(raveled[~isnull(raveled)])
     # these are chunks where a label is present
-    label_chunks = {lab: tuple(np.unique(which_chunk[raveled == lab])) for lab in unique_labels}
+    label_chunks = pd.Series(which_chunk).groupby(raveled).unique()
     # These invert the label_chunks mapping so we know which labels occur together.
-    chunks_cohorts = tlz.groupby(label_chunks.get, label_chunks.keys())
+    chunks_cohorts = tlz.groupby(lambda x: tuple(label_chunks.get(x)), label_chunks.keys())

     if merge:
         # First sort by number of chunks occupied by cohort
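The dropped dict comprehension rescanned the raveled label array once per unique label; pd.Series(which_chunk).groupby(raveled).unique() builds the same label-to-chunks mapping in a single groupby pass (and since pandas drops NaN group keys by default, the explicit NaN filtering could go too). A toy example of the new mapping and the tlz.groupby inversion, with made-up labels and chunk ids rather than flox internals:

import numpy as np
import pandas as pd
import toolz as tlz

# toy data: which chunk each element lives in, and its group label
which_chunk = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
raveled = np.array([10, 40, 20, 20, 30, 30, 10, 40, 30])

# one pass: label -> array of chunks where that label occurs,
# e.g. label 10 -> [0, 2], label 20 -> [0, 1], label 30 -> [1, 2], label 40 -> [0, 2]
label_chunks = pd.Series(which_chunk).groupby(raveled).unique()

# invert: labels occupying exactly the same set of chunks form one cohort,
# so labels 10 and 40 end up together here
chunks_cohorts = tlz.groupby(lambda x: tuple(label_chunks.get(x)), label_chunks.keys())
print(chunks_cohorts)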
@@ -892,7 +895,7 @@ def _grouped_combine(
         # when there's only a single axis of reduction, we can just concatenate later,
         # reindexing is unnecessary
         # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
-        unique_groups = np.unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
+        unique_groups = _unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
         unique_groups = unique_groups[~isnull(unique_groups)]
         if len(unique_groups) == 0:
             unique_groups = [np.nan]
@@ -1065,7 +1068,7 @@ def subset_to_blocks(
     unraveled = np.unravel_index(flatblocks, blkshape)
     normalized: list[Union[int, np.ndarray, slice]] = []
     for ax, idx in enumerate(unraveled):
-        i = np.unique(idx).squeeze()
+        i = _unique(idx).squeeze()
         if i.ndim == 0:
             normalized.append(i.item())
         else:
@@ -1310,7 +1313,7 @@ def dask_groupby_agg(
         # along the reduced axis
         slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
         if expected_groups is None:
-            groups_in_block = tuple(np.unique(by_input[slc]) for slc in slices)
+            groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
         else:
             # For cohorts, we could be indexing a block with groups that
             # are not in the cohort (usually for nD `by`)
