
Commit 72dfc87

More efficient cohorts. (#165)
Closes #140.

We now apply the cohort "split" step after the blockwise reduction, then use the tree reduction on each cohort, indexing out blocks with the `.blocks` accessor. This is still a bit inefficient: since we split by indexing out regular (non-ragged) arrays, we may index out blocks that contain no cohort members. However, because we split after the blockwise reduction, the duplicated work is much smaller than when splitting the bare array.

One side effect is that "split-reduce" is now a synonym for "cohorts"; I see no benefit to having a separate code path.

We also sort the cohorts at detection time to minimize shuffling of the final output when `sort=True` (the default). Finally, we avoid sorting when the groups are already sorted.
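A minimal usage sketch of what this enables (illustrative labels, sizes, and chunking, not taken from the commit; assumes a flox build containing this change):

```python
import dask.array as da
import numpy as np

from flox.core import groupby_reduce

# Periodic labels (think month-of-year): each chunk sees only a subset of all
# groups, which is exactly the pattern the "cohorts" strategy exploits.
labels = np.tile(np.arange(4), 6)  # 0 1 2 3 0 1 2 3 ...
array = da.random.random((10, labels.size), chunks=(-1, 2))

result, groups = groupby_reduce(array, labels, func="sum", method="cohorts")
print(groups)        # [0 1 2 3]
print(result.shape)  # (10, 4)
```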
1 parent: 05fe726

File tree: 5 files changed (+158, -107 lines)

docs/source/implementation.md (1 addition, 1 deletion)

```diff
@@ -13,7 +13,7 @@ or `xarray_reduce`.
 
 First we describe xarray's current strategy
 
-## `method="split-reduce"`: Xarray's current GroupBy strategy
+## Background: Xarray's current GroupBy strategy
 
 Xarray's current strategy is to find all unique group labels, index out each group,
 and then apply the reduction operation. Note that this only works if we know the group
```

flox/core.py (149 additions, 90 deletions)
```diff
@@ -137,7 +137,7 @@ def _get_optimal_chunks_for_groups(chunks, labels):
 
 
 @memoize
-def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "cohorts"):
+def find_group_cohorts(labels, chunks, merge: bool = True):
     """
     Finds groups labels that occur together aka "cohorts"
 
```
```diff
@@ -167,9 +167,6 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "cohorts"):
     # To do this, we must have values in memory so casting to numpy should be safe
     labels = np.asarray(labels)
 
-    if method == "split-reduce":
-        return list(_get_expected_groups(labels, sort=False).to_numpy().reshape(-1, 1))
-
     # Build an array with the shape of labels, but where every element is the "chunk number"
     # 1. First subset the array appropriately
     axis = range(-labels.ndim, 0)
```
```diff
@@ -195,7 +192,7 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "cohorts"):
     if merge:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
-            reversed(sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0])))
+            sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
         )
 
         items = tuple(sorted_chunks_cohorts.items())
```
```diff
@@ -218,9 +215,15 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "cohorts"):
                     merged_cohorts[k1].extend(v2)
                     merged_keys.append(k2)
 
-        return merged_cohorts.values()
+        # make sure each cohort is sorted after merging
+        sorted_merged_cohorts = {k: sorted(v) for k, v in merged_cohorts.items()}
+        # sort by first label in cohort
+        # This will help when sort=True (default)
+        # and we have to resort the dask array
+        return dict(sorted(sorted_merged_cohorts.items(), key=lambda kv: kv[1][0]))
+
     else:
-        return chunks_cohorts.values()
+        return chunks_cohorts
 
 
 def rechunk_for_cohorts(
```
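To make the new return value concrete, here is a small sketch (toy labels and chunking; the printed dict is indicative and assumes a flox build with this change):

```python
import numpy as np

from flox.core import find_group_cohorts

# Four chunks of length 4 along the grouped axis; every label lives in
# exactly one chunk, so each label is its own cohort.
labels = np.repeat([0, 1, 2, 3], 4)
chunks = [(4, 4, 4, 4)]

# Keys are tuples of the chunk indices a cohort occupies; values are the
# sorted cohort members; the dict itself is ordered by first label.
print(find_group_cohorts(labels, chunks))
# roughly: {(0,): [0], (1,): [1], (2,): [2], (3,): [3]}
```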
```diff
@@ -1079,6 +1082,63 @@ def _reduce_blockwise(
     return result
 
 
+def subset_to_blocks(
+    array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
+) -> DaskArray:
+    """
+    Advanced indexing of .blocks such that we always get a regular array back.
+
+    Parameters
+    ----------
+    array : dask.array
+    flatblocks : flat indices of blocks to extract
+    blkshape : shape of blocks with which to unravel flatblocks
+
+    Returns
+    -------
+    dask.array
+    """
+    if blkshape is None:
+        blkshape = array.blocks.shape
+
+    unraveled = np.unravel_index(flatblocks, blkshape)
+    normalized: list[Union[int, np.ndarray, slice]] = []
+    for ax, idx in enumerate(unraveled):
+        i = np.unique(idx).squeeze()
+        if i.ndim == 0:
+            normalized.append(i.item())
+        else:
+            if np.array_equal(i, np.arange(blkshape[ax])):
+                normalized.append(slice(None))
+            elif np.array_equal(i, np.arange(i[0], i[-1] + 1)):
+                normalized.append(slice(i[0], i[-1] + 1))
+            else:
+                normalized.append(i)
+    full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized)
+
+    # has no iterables
+    noiter = tuple(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
+    # has all iterables
+    alliter = {
+        ax: i if hasattr(i, "__len__") else slice(None) for ax, i in enumerate(full_normalized)
+    }
+
+    # apply everything but the iterables
+    if all(i == slice(None) for i in noiter):
+        return array
+
+    subset = array.blocks[noiter]
+
+    for ax, inds in alliter.items():
+        if isinstance(inds, slice):
+            continue
+        idxr = [slice(None, None)] * array.ndim
+        idxr[ax] = inds
+        subset = subset.blocks[tuple(idxr)]
+
+    return subset
+
+
 def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
     import dask.array
     from dask.highlevelgraph import HighLevelGraph
```
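A hedged sketch of `subset_to_blocks` in action on a toy array (assumes a flox build with this change):

```python
import dask.array as da

from flox.core import subset_to_blocks

x = da.ones((4, 12), chunks=(2, 3))  # a 2 x 4 grid of blocks
print(x.numblocks)  # (2, 4)

# Flat block indices 0, 1, 3 unravel to block-row 0 and block-columns [0, 1, 3].
# The row index collapses to a scalar; the columns need one fancy-index pass.
sub = subset_to_blocks(x, [0, 1, 3], blkshape=(2, 4))
print(sub.numblocks)  # (1, 3)
print(sub.shape)      # (2, 9)
```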
```diff
@@ -1115,6 +1175,7 @@ def dask_groupby_agg(
     reindex: bool = False,
     engine: T_Engine = "numpy",
     sort: bool = True,
+    chunks_cohorts=None,
 ) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
 
     import dask.array
```
```diff
@@ -1194,7 +1255,7 @@ def dask_groupby_agg(
         partial(
             blockwise_method,
             axis=axis,
-            expected_groups=expected_groups,
+            expected_groups=None if method in ["split-reduce", "cohorts"] else expected_groups,
             engine=engine,
             sort=sort,
         ),
```
```diff
@@ -1223,43 +1284,77 @@ def dask_groupby_agg(
         expected_groups = _get_expected_groups(by_input, sort=sort)
     group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
 
-    if method == "map-reduce":
+    if method in ["map-reduce", "cohorts", "split-reduce"]:
         combine: Callable[..., IntermediateDict]
         if do_simple_combine:
             combine = _simple_combine
         else:
             combine = partial(_grouped_combine, engine=engine, sort=sort)
 
-        # reduced is really a dict mapping reduction name to array
-        # and "groups" to an array of group labels
+        # Each chunk of `reduced` is really a dict mapping
+        # 1. reduction name to array
+        # 2. "groups" to an array of group labels
         # Note: it does not make sense to interpret axis relative to
         # shape of intermediate results after the blockwise call
-        reduced = dask.array.reductions._tree_reduce(
-            intermediate,
-            aggregate=partial(
-                _aggregate,
-                combine=combine,
-                agg=agg,
-                expected_groups=None if split_out > 1 else expected_groups,
-                fill_value=fill_value,
-                reindex=reindex,
-            ),
+        tree_reduce = partial(
+            dask.array.reductions._tree_reduce,
             combine=partial(combine, agg=agg),
-            name=f"{name}-reduce",
+            name=f"{name}-reduce-{method}",
             dtype=array.dtype,
             axis=axis,
             keepdims=True,
             concatenate=False,
         )
-
-        if is_duck_dask_array(by_input) and expected_groups is None:
-            groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
-        else:
-            if expected_groups is None:
-                expected_groups_ = _get_expected_groups(by_input, sort=sort)
+        aggregate = partial(
+            _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
+        )
+        if method == "map-reduce":
+            reduced = tree_reduce(
+                intermediate,
+                aggregate=partial(
+                    aggregate, expected_groups=None if split_out > 1 else expected_groups
+                ),
+            )
+            if is_duck_dask_array(by_input) and expected_groups is None:
+                groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
             else:
-                expected_groups_ = expected_groups
-            groups = (expected_groups_.to_numpy(),)
+                if expected_groups is None:
+                    expected_groups_ = _get_expected_groups(by_input, sort=sort)
+                else:
+                    expected_groups_ = expected_groups
+                groups = (expected_groups_.to_numpy(),)
+
+        elif method in ["cohorts", "split-reduce"]:
+            chunks_cohorts = find_group_cohorts(
+                by_input, [array.chunks[ax] for ax in axis], merge=True
+            )
+            reduced_ = []
+            groups_ = []
+            for blks, cohort in chunks_cohorts.items():
+                subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
+                if do_simple_combine:
+                    # reindex so that reindex can be set to True later
+                    reindexed = dask.array.map_blocks(
+                        reindex_intermediates,
+                        subset,
+                        agg=agg,
+                        unique_groups=cohort,
+                        meta=subset._meta,
+                    )
+                else:
+                    reindexed = subset
+
+                reduced_.append(
+                    tree_reduce(
+                        reindexed,
+                        aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
+                    )
+                )
+                groups_.append(cohort)
+
+            reduced = dask.array.concatenate(reduced_, axis=-1)
+            groups = (np.concatenate(groups_),)
+            group_chunks = (tuple(len(cohort) for cohort in groups_),)
 
     elif method == "blockwise":
         reduced = intermediate
```
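As a sanity check on the new per-cohort path, a sketch comparing the two methods on toy data (assumes a flox build with this change):

```python
import dask.array as da
import numpy as np

from flox.core import groupby_reduce

labels = np.tile([0, 1, 2, 3], 3)                  # 12 labels
array = da.random.random((5, 12), chunks=(-1, 3))  # 4 blocks along the label axis

expected, _ = groupby_reduce(array, labels, func="sum", method="map-reduce")
actual, _ = groupby_reduce(array, labels, func="sum", method="cohorts")
np.testing.assert_allclose(expected.compute(), actual.compute())
```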
```diff
@@ -1297,7 +1392,11 @@ def dask_groupby_agg(
                 nblocks = tuple(len(array.chunks[ax]) for ax in axis)
                 inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
             else:
-                inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1)
+                inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)
+                if split_out > 1:
+                    inchunk = inchunk + (0,)
+                inchunk = inchunk + (ochunk[-1],)
+
             layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)
 
     result = dask.array.Array(
```
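Worked in isolation with toy values (not from the commit), the new chunk-key arithmetic does the following; the old one-liner returned (0, 0, 0) here, which was only correct while the reduced output had a single chunk along the group axis:

```python
# Output chunk key (0, 5): position 0 on the preserved axis, cohort piece 5
# along the group axis; two reduced axes each collapse to input chunk 0.
ochunk, axis, split_out = (0, 5), (-2, -1), 1

inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)
if split_out > 1:
    inchunk = inchunk + (0,)
inchunk = inchunk + (ochunk[-1],)

print(inchunk)  # (0, 0, 5)
```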
```diff
@@ -1326,6 +1425,9 @@ def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_gro
     if method in ["split-reduce", "cohorts"] and reindex is False:
         raise NotImplementedError
 
+    if method in ["split-reduce", "cohorts"] and reindex is None:
+        reindex = True
+
     # TODO: Should reindex be a bool-only at this point? Would've been nice but
     # None's are relied on after this function as well.
     return reindex
```
```diff
@@ -1480,9 +1582,7 @@ def groupby_reduce(
         method by first rechunking using ``rechunk_for_cohorts``
         (for 1D ``by`` only).
       * ``"split-reduce"``:
-        Break out each group into its own array and then ``"map-reduce"``.
-        This is implemented by having each group be its own cohort,
-        and is identical to xarray's default strategy.
+        Same as "cohorts" and will be removed soon.
     engine : {"flox", "numpy", "numba"}, optional
         Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
         * ``"numpy"``:
```
```diff
@@ -1652,67 +1752,26 @@ def groupby_reduce(
 
     partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs)
 
-    if method in ["split-reduce", "cohorts"]:
-        cohorts = find_group_cohorts(
-            by_, [array.chunks[ax] for ax in axis_], merge=True, method=method
-        )
-
-        results_ = []
-        groups_ = []
-        for cohort in cohorts:
-            cohort = sorted(cohort)
-            # equivalent of xarray.DataArray.where(mask, drop=True)
-            mask = np.isin(by_, cohort)
-            indexer = [np.unique(v) for v in np.nonzero(mask)]
-            array_subset = array
-            for ax, idxr in zip(range(-by_.ndim, 0), indexer):
-                array_subset = np.take(array_subset, idxr, axis=ax)
-            numblocks = math.prod([len(array_subset.chunks[ax]) for ax in axis_])
-
-            # get final result for these groups
-            r, *g = partial_agg(
-                array_subset,
-                by_[np.ix_(*indexer)],
-                expected_groups=pd.Index(cohort),
-                # First deep copy becasue we might be doping blockwise,
-                # which sets agg.finalize=None, then map-reduce (GH102)
-                agg=copy.deepcopy(agg),
-                # reindex to expected_groups at the blockwise step.
-                # this approach avoids replacing non-cohort members with
-                # np.nan or some other sentinel value, and preserves dtypes
-                reindex=True,
-                # sort controls the final output order so apply that at the end
-                sort=False,
-                # if only a single block along axis, we can just work blockwise
-                # inspired by https://github.com/dask/dask/issues/8361
-                method="blockwise" if numblocks == 1 and nax == by_.ndim else "map-reduce",
-            )
-            results_.append(r)
-            groups_.append(cohort)
+    if method == "blockwise" and by_.ndim == 1:
+        array = rechunk_for_blockwise(array, axis=-1, labels=by_)
 
-        # concatenate results together,
-        # sort to make sure we match expected output
-        groups = (np.hstack(groups_),)
-        result = np.concatenate(results_, axis=-1)
-    else:
-        if method == "blockwise" and by_.ndim == 1:
-            array = rechunk_for_blockwise(array, axis=-1, labels=by_)
-
-        result, groups = partial_agg(
-            array,
-            by_,
-            expected_groups=None if method == "blockwise" else expected_groups,
-            agg=agg,
-            reindex=reindex,
-            method=method,
-            sort=sort,
-        )
+    result, groups = partial_agg(
+        array,
+        by_,
+        expected_groups=None if method == "blockwise" else expected_groups,
+        agg=agg,
+        reindex=reindex,
+        method=method,
+        sort=sort,
+    )
 
     if sort and method != "map-reduce":
         assert len(groups) == 1
         sorted_idx = np.argsort(groups[0])
-        result = result[..., sorted_idx]
-        groups = (groups[0][sorted_idx],)
+        # This optimization helps specifically with resampling
+        if not (sorted_idx[:-1] <= sorted_idx[1:]).all():
+            result = result[..., sorted_idx]
+            groups = (groups[0][sorted_idx],)
 
     if factorize_early:
         # nan group labels are factorized to -1, and preserved
```
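The already-sorted check above, exercised on its own with toy arrays:

```python
import numpy as np

for groups in (np.array([0, 1, 2, 3]), np.array([2, 3, 0, 1])):
    sorted_idx = np.argsort(groups)
    needs_sort = not (sorted_idx[:-1] <= sorted_idx[1:]).all()
    print(groups, needs_sort)
# [0 1 2 3] False  -> skip the fancy-index copy
# [2 3 0 1] True   -> reorder result and groups
```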

flox/visualize.py (2 additions, 2 deletions)

```diff
@@ -136,10 +136,10 @@ def visualize_cohorts_2d(by, array, method="cohorts"):
     print("finding cohorts...")
     before_merged = find_group_cohorts(
         by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=False, method=method
-    )
+    ).values()
     merged = find_group_cohorts(
         by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=True, method=method
-    )
+    ).values()
     print("finished cohorts...")
 
     xticks = np.cumsum(array.chunks[-1])
```

flox/xarray.py (1 addition, 3 deletions)

```diff
@@ -126,9 +126,7 @@ def xarray_reduce(
         method by first rechunking using ``rechunk_for_cohorts``
         (for 1D ``by`` only).
       * ``"split-reduce"``:
-        Break out each group into its own array and then ``"map-reduce"``.
-        This is implemented by having each group be its own cohort,
-        and is identical to xarray's default strategy.
+        Same as "cohorts" and will be removed soon.
     engine : {"flox", "numpy", "numba"}, optional
         Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
         * ``"numpy"``:
```
