Commit 394925e

Support reindexing in simple_combine

For 1D combine, this is a great improvement for cohorts-type reductions; map-reduce uses more memory but takes similar time. Note that the map-reduce intermediates are a worst case where there are no shared groups between the chunks being combined. That case is actually optimized in _grouped_combine, where reindexing is skipped when reducing along a single axis.

Peak memory:

[ 68.75%] ···
    =========== ========= =========
    --               combine
    ----------- -------------------
       kind      grouped   combine
    =========== ========= =========
      cohorts     760M      631M
     mapreduce    981M      1.81G
    =========== ========= =========

Time:

[ 75.00%] ···
    =========== ========== ===========
    --                 combine
    ----------- ----------------------
       kind      grouped     combine
    =========== ========== ===========
      cohorts    393±10ms    137±10ms
     mapreduce   652±10ms    611±400ms
    =========== ========== ===========

Also fix a bug in unique.

1 parent c370b5d commit 394925e
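
A note on the mechanics: _simple_combine previously required that every blockwise intermediate already be reindexed to the full expected-groups index. With this change it can align intermediates itself, reindexing each one to the union of the group labels present before reducing along the stacked dummy axis. A minimal standalone sketch of that alignment step (reindex_to_union and the dict layout here are illustrative only; the real helpers are _find_unique_groups and reindex_intermediates in the diff below):

```python
import numpy as np

def reindex_to_union(values, groups, union_groups, fill_value=0.0):
    # Scatter one chunk's per-group values into the slots of the union index.
    out = np.full(len(union_groups), fill_value, dtype=values.dtype)
    out[np.searchsorted(union_groups, groups)] = values
    return out

# Two intermediates with no shared groups: the map-reduce worst case,
# where reindexing roughly doubles the footprint before the reduction.
chunk_a = {"groups": np.array([1, 2]), "sum": np.array([10.0, 20.0])}
chunk_b = {"groups": np.array([3, 4]), "sum": np.array([5.0, 7.0])}

union = np.unique(np.concatenate([chunk_a["groups"], chunk_b["groups"]]))
stacked = np.stack(
    [reindex_to_union(c["sum"], c["groups"], union) for c in (chunk_a, chunk_b)]
)
print(union)                # [1 2 3 4]
print(stacked.sum(axis=0))  # [10. 20.  5.  7.]
```

Once all intermediates share one index, combining is an ordinary reduction along the stacking axis (flox's DUMMY_AXIS). The cost is materializing the fill values, which is why map-reduce peak memory grows in the table above while cohorts, whose chunks share all their groups, improves substantially.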

File tree

3 files changed: 118 additions & 85 deletions

asv_bench/benchmarks/combine.py

Lines changed: 19 additions & 10 deletions
@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy as np
 
 import flox
@@ -7,26 +9,31 @@
 N = 1000
 
 
+def _get_combine(combine):
+    if combine == "grouped":
+        return partial(flox.core._grouped_combine, engine="numpy")
+    else:
+        return partial(flox.core._simple_combine, reindex=False)
+
+
 class Combine:
     def setup(self, *args, **kwargs):
         raise NotImplementedError
 
-    @parameterized("kind", ("cohorts", "mapreduce"))
-    def time_combine(self, kind):
-        flox.core._grouped_combine(
+    @parameterized(("kind", "combine"), (("reindexed", "not_reindexed"), ("grouped", "simple")))
+    def time_combine(self, kind, combine):
+        _get_combine(combine)(
            getattr(self, f"x_chunk_{kind}"),
            **self.kwargs,
            keepdims=True,
-           engine="numpy",
        )
 
-    @parameterized("kind", ("cohorts", "mapreduce"))
-    def peakmem_combine(self, kind):
-        flox.core._grouped_combine(
+    @parameterized(("kind", "combine"), (("reindexed", "not_reindexed"), ("grouped", "simple")))
+    def peakmem_combine(self, kind, combine):
+        _get_combine(combine)(
            getattr(self, f"x_chunk_{kind}"),
            **self.kwargs,
            keepdims=True,
-           engine="numpy",
        )
 
 
@@ -47,7 +54,7 @@ def construct_member(groups):
            }
 
        # motivated by
-        self.x_chunk_mapreduce = [
+        self.x_chunk_not_reindexed = [
            construct_member(groups)
            for groups in [
                np.array((1, 2, 3, 4)),
@@ -57,5 +64,7 @@ def construct_member(groups):
            * 2
        ]
 
-        self.x_chunk_cohorts = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4]
+        self.x_chunk_reindexed = [
+            construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4
+        ]
        self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,)}
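
For orientation, construct_member is defined outside this hunk. Each member it builds is a dict in the layout both combine paths consume: "groups" maps to the chunk's group labels, and "intermediates" to one array per step of the aggregation. A hypothetical stand-in (the real definition may differ), sized so that the group axis is axis 3 to match kwargs["axis"] == (3,):

```python
import numpy as np

# Hypothetical stand-in for construct_member (its real definition is not
# shown in this diff): `mean` carries a running sum plus a count, so there
# are two intermediate arrays, and the group axis is the last of four dims.
def construct_member_sketch(groups):
    shape = (1, 1, 1, len(groups))
    return {
        "groups": np.broadcast_to(groups, shape),
        "intermediates": [
            np.random.rand(*shape),          # partial sums
            np.ones(shape, dtype=np.int64),  # counts
        ],
    }

member = construct_member_sketch(np.array((1, 2, 3, 4)))
print(member["groups"].shape)  # (1, 1, 1, 4)
```

Per the commit message, the not_reindexed members share no groups between chunks, so reindexing to the union is the worst case exercised by this benchmark.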

flox/core.py

Lines changed: 59 additions & 44 deletions
@@ -816,8 +816,25 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
     return results
 
 
+def _find_unique_groups(x_chunk):
+    from dask.base import flatten
+    from dask.utils import deepmap
+
+    unique_groups = _unique(np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk)))))
+    unique_groups = unique_groups[~isnull(unique_groups)]
+
+    if len(unique_groups) == 0:
+        unique_groups = [np.nan]
+    return unique_groups
+
+
 def _simple_combine(
-    x_chunk, agg: Aggregation, axis: T_Axes, keepdims: bool, is_aggregate: bool = False
+    x_chunk,
+    agg: Aggregation,
+    axis: T_Axes,
+    keepdims: bool,
+    reindex: bool,
+    is_aggregate: bool = False,
 ) -> IntermediateDict:
     """
     'Simple' combination of blockwise results.
@@ -830,8 +847,19 @@
     4. At the final aggregate step, we squeeze out DUMMY_AXIS
     """
     from dask.array.core import deepfirst
+    from dask.utils import deepmap
+
+    if not reindex:
+        # We didn't reindex at the blockwise step
+        # So now reindex before combining by reducing along DUMMY_AXIS
+        unique_groups = _find_unique_groups(x_chunk)
+        x_chunk = deepmap(
+            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
+        )
+    else:
+        unique_groups = deepfirst(x_chunk)["groups"]
 
-    results: IntermediateDict = {"groups": deepfirst(x_chunk)["groups"]}
+    results: IntermediateDict = {"groups": unique_groups}
     results["intermediates"] = []
     axis_ = axis[:-1] + (DUMMY_AXIS,)
     for idx, combine in enumerate(agg.combine):
@@ -886,7 +914,6 @@ def _grouped_combine(
     sort: bool = True,
 ) -> IntermediateDict:
     """Combine intermediates step of tree reduction."""
-    from dask.base import flatten
     from dask.utils import deepmap
 
     if isinstance(x_chunk, dict):
@@ -897,11 +924,7 @@
        # when there's only a single axis of reduction, we can just concatenate later,
        # reindexing is unnecessary
        # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
-        unique_groups = _unique(np.array(tuple(flatten(deepmap(listify_groups, x_chunk)))))
-        unique_groups = unique_groups[~isnull(unique_groups)]
-        if len(unique_groups) == 0:
-            unique_groups = [np.nan]
-
+        unique_groups = _find_unique_groups(x_chunk)
        x_chunk = deepmap(
            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
        )
@@ -1216,7 +1239,8 @@ def dask_groupby_agg(
    # This allows us to discover groups at compute time, support argreductions, lower intermediate
    # memory usage (but method="cohorts" would also work to reduce memory in some cases)
 
-    do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
+    do_simple_combine = not _is_arg_reduction(agg)
+
    if method == "blockwise":
        # use the "non dask" code path, but applied blockwise
        blockwise_method = partial(
@@ -1268,31 +1292,32 @@
    if method in ["map-reduce", "cohorts"]:
        combine: Callable[..., IntermediateDict]
        if do_simple_combine:
-            combine = _simple_combine
+            combine = partial(_simple_combine, reindex=reindex)
+            combine_name = "simple-combine"
        else:
            combine = partial(_grouped_combine, engine=engine, sort=sort)
+            combine_name = "grouped-combine"
 
-        # Each chunk of `reduced` is really a dict mapping
-        # 1. reduction name to array
-        # 2. "groups" to an array of group labels
-        # Note: it does not make sense to interpret axis relative to
-        # shape of intermediate results after the blockwise call
        tree_reduce = partial(
            dask.array.reductions._tree_reduce,
-            combine=partial(combine, agg=agg),
-            name=f"{name}-reduce-{method}",
+            name=f"{name}-reduce-{method}-{combine_name}",
            dtype=array.dtype,
            axis=axis,
            keepdims=True,
            concatenate=False,
        )
-        aggregate = partial(
-            _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
-        )
+        aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value)
+
+        # Each chunk of `reduced` is really a dict mapping
+        # 1. reduction name to array
+        # 2. "groups" to an array of group labels
+        # Note: it does not make sense to interpret axis relative to
+        # shape of intermediate results after the blockwise call
        if method == "map-reduce":
            reduced = tree_reduce(
                intermediate,
-                aggregate=partial(aggregate, expected_groups=expected_groups),
+                combine=partial(combine, agg=agg),
+                aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
            )
            if is_duck_dask_array(by_input) and expected_groups is None:
                groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
@@ -1310,23 +1335,17 @@
            reduced_ = []
            groups_ = []
            for blks, cohort in chunks_cohorts.items():
+                index = pd.Index(cohort)
                subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
-                if do_simple_combine:
-                    # reindex so that reindex can be set to True later
-                    reindexed = dask.array.map_blocks(
-                        reindex_intermediates,
-                        subset,
-                        agg=agg,
-                        unique_groups=cohort,
-                        meta=subset._meta,
-                    )
-                else:
-                    reindexed = subset
-
+                reindexed = dask.array.map_blocks(
+                    reindex_intermediates, subset, agg=agg, unique_groups=index, meta=subset._meta
+                )
+                # now that we have reindexed, we can set reindex=True explicitly
                reduced_.append(
                    tree_reduce(
                        reindexed,
-                        aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
+                        combine=partial(combine, agg=agg, reindex=True),
+                        aggregate=partial(aggregate, expected_groups=index, reindex=True),
                    )
                )
                groups_.append(cohort)
@@ -1382,28 +1401,24 @@ def _validate_reindex(
    if reindex is True:
        if _is_arg_reduction(func):
            raise NotImplementedError
-        if method == "blockwise":
-            raise NotImplementedError
+        if method in ["blockwise", "cohorts"]:
+            raise ValueError(
+                "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
+            )
 
    if reindex is None:
        if method == "blockwise" or _is_arg_reduction(func):
            reindex = False
 
-        elif expected_groups is not None:
-            reindex = True
-
-        elif method in ["split-reduce", "cohorts"]:
-            reindex = True
+        elif method == "cohorts":
+            reindex = False
 
        elif method == "map-reduce":
            if expected_groups is None and by_is_dask:
                reindex = False
            else:
                reindex = True
 
-    if method in ["split-reduce", "cohorts"] and reindex is False:
-        raise NotImplementedError
-
    assert isinstance(reindex, bool)
    return reindex
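
Behaviorally, the _validate_reindex changes amount to: reindex=None now resolves to False for method="cohorts" (each cohort is reindexed blockwise to its own small index and then combined with reindex=True), and passing reindex=True with cohorts or blockwise is rejected with a ValueError rather than NotImplementedError. A usage sketch, assuming flox's public groupby_reduce as of this commit:

```python
import dask.array as da
import numpy as np
from flox.core import groupby_reduce

array = da.ones((12,), chunks=3)
by = np.repeat([10, 20, 30, 40], 3)  # one group per block: ideal cohorts

# reindex=None resolves to False for cohorts; each cohort is reindexed
# blockwise to its own small index, not to the full set of groups.
result, groups = groupby_reduce(array, by, func="sum", method="cohorts")

# Forcing reindex=True with cohorts (or blockwise) now raises:
#   groupby_reduce(array, by, func="sum", method="cohorts", reindex=True)
#   -> ValueError: reindex=True is not a valid choice for
#      method='blockwise' or method='cohorts'.
```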

tests/test_core.py

Lines changed: 40 additions & 31 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import itertools
 from functools import partial, reduce
 from typing import TYPE_CHECKING
 
@@ -219,29 +220,31 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
            assert actual.dtype.kind == "i"
        assert_equal(actual, expected, tolerance)
 
-        if not has_dask:
+        if not has_dask or chunks is None:
            continue
-        for method in ["map-reduce", "cohorts", "split-reduce"]:
-            if method == "map-reduce":
-                reindexes = [True, False, None]
+
+        params = list(itertools.product(["map-reduce"], [True, False, None]))
+        params.extend(itertools.product(["cohorts"], [False, None]))
+        for method, reindex in params:
+            call = partial(
+                groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
+            )
+            if "arg" in func and reindex is True:
+                # simple_combine with argreductions not supported right now
+                with pytest.raises(NotImplementedError):
+                    call()
+                continue
+            actual, *groups = call()
+            if "arg" not in func:
+                # make sure we use simple combine
+                assert any("simple-combine" in key for key in actual.dask.layers.keys())
            else:
-                reindexes = [None]
-            for reindex in reindexes:
-                call = partial(
-                    groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
-                )
-                if "arg" in func:
-                    if method != "map-reduce" or reindex is True:
-                        with pytest.raises(NotImplementedError):
-                            call()
-                        continue
-
-                actual, *groups = call()
-                for actual_group, expect in zip(groups, expected_groups):
-                    assert_equal(actual_group, expect, tolerance)
-                if "arg" in func:
-                    assert actual.dtype.kind == "i"
-                assert_equal(actual, expected, tolerance)
+                assert any("grouped-combine" in key for key in actual.dask.layers.keys())
+            for actual_group, expect in zip(groups, expected_groups):
+                assert_equal(actual_group, expect, tolerance)
+            if "arg" in func:
+                assert actual.dtype.kind == "i"
+            assert_equal(actual, expected, tolerance)
 
 
 @requires_dask
@@ -1140,7 +1143,6 @@ def test_subset_block_2d(flatblocks, expectidx):
    assert_equal(subset, array.compute()[expectidx])
 
 
-@pytest.mark.parametrize("method", ["map-reduce", "cohorts"])
 @pytest.mark.parametrize(
    "expected, reindex, func, expected_groups, by_is_dask",
    [
@@ -1158,13 +1160,20 @@ def test_subset_block_2d(flatblocks, expectidx):
        [True, None, "sum", ([1], None), True],
    ],
 )
-def test_validate_reindex(expected, reindex, func, method, expected_groups, by_is_dask):
-    if by_is_dask and method == "cohorts":
-        # This should error elsewhere
-        pytest.skip()
-    call = partial(_validate_reindex, reindex, func, method, expected_groups, by_is_dask)
-    if "arg" in func and method == "cohorts":
+def test_validate_reindex_map_reduce(expected, reindex, func, expected_groups, by_is_dask):
+    actual = _validate_reindex(reindex, func, "map-reduce", expected_groups, by_is_dask)
+    assert actual == expected
+
+
+def test_validate_reindex():
+    for method in ["map-reduce", "cohorts"]:
        with pytest.raises(NotImplementedError):
-            call()
-    else:
-        assert call() == expected
+            _validate_reindex(True, "argmax", method, expected_groups=None, by_is_dask=False)
+
+    for method in ["blockwise", "cohorts"]:
+        with pytest.raises(ValueError):
+            _validate_reindex(True, "sum", method, expected_groups=None, by_is_dask=False)
+
+    for func in ["sum", "argmax"]:
+        actual = _validate_reindex(None, func, method, expected_groups=None, by_is_dask=False)
+        assert actual is False
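
The new layer-name assertions double as a handy way to verify which combine path a graph took; a distilled sketch (assuming flox at this commit and dask's HighLevelGraph.layers mapping):

```python
import dask.array as da
import numpy as np
from flox.core import groupby_reduce

array = da.ones((12,), chunks=3)
by = np.repeat([1, 2, 3, 4], 3)

# Non-arg reductions now always take the simple-combine path...
actual, groups = groupby_reduce(array, by, func="mean")
assert any("simple-combine" in key for key in actual.dask.layers)

# ...while argreductions still fall back to the grouped combine.
actual, groups = groupby_reduce(array, by, func="argmax")
assert any("grouped-combine" in key for key in actual.dask.layers)
```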
