
Commit 878e284

Remove split_out (#170)
1 parent 03107af commit 878e284

3 files changed: +19 −84 lines changed

flox/core.py

Lines changed: 18 additions & 77 deletions
@@ -736,13 +736,6 @@ def _squeeze_results(results: IntermediateDict, axis: T_Axes) -> IntermediateDict:
     return newresults


-def _split_groups(array, j, slicer):
-    """Slices out chunks when split_out > 1"""
-    results = {"groups": array["groups"][..., slicer]}
-    results["intermediates"] = [v[..., slicer] for v in array["intermediates"]]
-    return results
-
-
 def _finalize_results(
     results: IntermediateDict,
     agg: Aggregation,
@@ -997,38 +990,6 @@ def _grouped_combine(
     return results


-def split_blocks(applied, split_out, expected_groups, split_name):
-    import dask.array
-    from dask.array.core import normalize_chunks
-    from dask.highlevelgraph import HighLevelGraph
-
-    chunk_tuples = tuple(itertools.product(*tuple(range(n) for n in applied.numblocks)))
-    ngroups = len(expected_groups)
-    group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))
-    idx = tuple(np.cumsum((0,) + group_chunks[0]))
-
-    # split each block into `split_out` chunks
-    dsk = {}
-    for i in chunk_tuples:
-        for j in range(split_out):
-            dsk[(split_name, *i, j)] = (
-                _split_groups,
-                (applied.name, *i),
-                j,
-                slice(idx[j], idx[j + 1]),
-            )
-
-    # now construct an array that can be passed to _tree_reduce
-    intergraph = HighLevelGraph.from_collections(split_name, dsk, dependencies=(applied,))
-    intermediate = dask.array.Array(
-        intergraph,
-        name=split_name,
-        chunks=applied.chunks + ((1,) * split_out,),
-        meta=applied._meta,
-    )
-    return intermediate, group_chunks
-
-
 def _reduce_blockwise(
     array,
     by,
@@ -1169,7 +1130,6 @@ def dask_groupby_agg(
     agg: Aggregation,
     expected_groups: pd.Index | None,
     axis: T_Axes = (),
-    split_out: int = 1,
     fill_value: Any = None,
     method: T_Method = "map-reduce",
     reindex: bool = False,
@@ -1186,19 +1146,14 @@ def dask_groupby_agg(
     assert isinstance(axis, Sequence)
     assert all(ax >= 0 for ax in axis)

-    if method == "blockwise" and (split_out > 1 or not isinstance(by, np.ndarray)):
-        raise NotImplementedError
-
-    if split_out > 1 and expected_groups is None:
-        # This could be implemented using the "hash_split" strategy
-        # from dask.dataframe
+    if method == "blockwise" and not isinstance(by, np.ndarray):
         raise NotImplementedError

     inds = tuple(range(array.ndim))
     name = f"groupby_{agg.name}"
-    token = dask.base.tokenize(array, by, agg, expected_groups, axis, split_out)
+    token = dask.base.tokenize(array, by, agg, expected_groups, axis)

-    if expected_groups is None and (reindex or split_out > 1):
+    if expected_groups is None and reindex:
         expected_groups = _get_expected_groups(by, sort=sort)

     by_input = by
@@ -1229,9 +1184,7 @@ def dask_groupby_agg(
     # This allows us to discover groups at compute time, support argreductions, lower intermediate
     # memory usage (but method="cohorts" would also work to reduce memory in some cases)

-    do_simple_combine = (
-        method != "blockwise" and reindex and not _is_arg_reduction(agg) and split_out == 1
-    )
+    do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
     if method == "blockwise":
         # use the "non dask" code path, but applied blockwise
         blockwise_method = partial(
@@ -1244,14 +1197,14 @@ def dask_groupby_agg(
             func=agg.chunk,
             fill_value=agg.fill_value["intermediate"],
             dtype=agg.dtype["intermediate"],
-            reindex=reindex or (split_out > 1),
+            reindex=reindex,
         )
         if do_simple_combine:
             # Add a dummy dimension that then gets reduced over
             blockwise_method = tlz.compose(_expand_dims, blockwise_method)

     # apply reduction on chunk
-    applied = dask.array.blockwise(
+    intermediate = dask.array.blockwise(
         partial(
             blockwise_method,
             axis=axis,
@@ -1271,18 +1224,14 @@ def dask_groupby_agg(
         token=f"{name}-chunk-{token}",
     )

-    if split_out > 1:
-        intermediate, group_chunks = split_blocks(
-            applied, split_out, expected_groups, split_name=f"{name}-split-{token}"
-        )
-    else:
-        intermediate = applied
-        if expected_groups is None:
-            if is_duck_dask_array(by_input):
-                expected_groups = None
-            else:
-                expected_groups = _get_expected_groups(by_input, sort=sort)
-        group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
+    if expected_groups is None:
+        if is_duck_dask_array(by_input):
+            expected_groups = None
+        else:
+            expected_groups = _get_expected_groups(by_input, sort=sort)
+    group_chunks: tuple[tuple[Union[int, float], ...]] = (
+        (len(expected_groups),) if expected_groups is not None else (np.nan,),
+    )

     if method in ["map-reduce", "cohorts", "split-reduce"]:
         combine: Callable[..., IntermediateDict]
@@ -1311,9 +1260,7 @@ def dask_groupby_agg(
         if method == "map-reduce":
             reduced = tree_reduce(
                 intermediate,
-                aggregate=partial(
-                    aggregate, expected_groups=None if split_out > 1 else expected_groups
-                ),
+                aggregate=partial(aggregate, expected_groups=expected_groups),
             )
             if is_duck_dask_array(by_input) and expected_groups is None:
                 groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
@@ -1380,7 +1327,7 @@ def dask_groupby_agg(
         raise ValueError(f"Unknown method={method}.")

     # extract results from the dict
-    output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
+    output_chunks = reduced.chunks[: -len(axis)] + group_chunks
     ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
     layer2: dict[tuple, tuple] = {}
     agg_name = f"{name}-{token}"
@@ -1392,10 +1339,7 @@ def dask_groupby_agg(
             nblocks = tuple(len(array.chunks[ax]) for ax in axis)
             inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
         else:
-            inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)
-            if split_out > 1:
-                inchunk = inchunk + (0,)
-            inchunk = inchunk + (ochunk[-1],)
+            inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)

         layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

@@ -1516,7 +1460,6 @@ def groupby_reduce(
     fill_value=None,
     dtype: np.typing.DTypeLike = None,
     min_count: int | None = None,
-    split_out: int = 1,
     method: T_Method = "map-reduce",
     engine: T_Engine = "numpy",
     reindex: bool | None = None,
@@ -1555,8 +1498,6 @@ def groupby_reduce(
         fewer than min_count non-NA values are present the result will be
        NA. Only used if skipna is set to True or defaults to True for the
        array's dtype.
-    split_out : int, optional
-        Number of chunks along group axis in output (last axis)
     method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
         Strategy for reduction of dask arrays only:
           * ``"map-reduce"``:
@@ -1750,7 +1691,7 @@ def groupby_reduce(
         if kwargs["fill_value"] is None:
             kwargs["fill_value"] = agg.fill_value[agg.name]

-        partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs)
+        partial_agg = partial(dask_groupby_agg, **kwargs)

         if method == "blockwise" and by_.ndim == 1:
             array = rechunk_for_blockwise(array, axis=-1, labels=by_)
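
For callers of flox.core.groupby_reduce the change is simply that the keyword disappears; chunking of the output along the group axis is no longer configurable at reduction time. Below is a minimal sketch, not part of the commit, of a post-change call. The call style (func=, expected_groups=, fill_value=) mirrors the test updated in tests/test_core.py further down; the dummy data, the (result, groups) unpacking, and the trailing rechunk (one way to recover several output chunks along the group axis, roughly what split_out used to control) are illustrative assumptions.

import dask.array as da
import numpy as np

from flox.core import groupby_reduce

# Illustrative inputs: a chunked 1-D dask array and integer group labels.
array = da.from_array(np.arange(12.0), chunks=4)
by = np.repeat([0, 1, 2], 4)

# After this commit there is no ``split_out=`` keyword to pass.
result, groups = groupby_reduce(
    array,
    by,
    func="sum",
    expected_groups=[0, 1, 2],
    fill_value=123,
)

# ``split_out`` used to set the number of output chunks along the group axis;
# rechunking the reduced dask array afterwards gives a similar layout.
result = result.rechunk(1)  # three chunks of one group each, instead of one chunk of three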

flox/xarray.py

Lines changed: 0 additions & 4 deletions
@@ -62,7 +62,6 @@ def xarray_reduce(
     isbin: bool | Sequence[bool] = False,
     sort: bool = True,
     dim: Dims | ellipsis = None,
-    split_out: int = 1,
     fill_value=None,
     dtype: np.typing.DTypeLike = None,
     method: str = "map-reduce",
@@ -95,8 +94,6 @@ def xarray_reduce(
     dim : hashable
         dimension name along which to reduce. If None, reduces across all
         dimensions of `by`
-    split_out : int, optional
-        Number of output chunks along grouped dimension in output.
     fill_value
         Value used for missing groups in the output i.e. when one of the labels
         in ``expected_groups`` is not actually present in ``by``.
@@ -397,7 +394,6 @@ def wrapper(array, *by, func, skipna, **kwargs):
            "func": func,
            "axis": axis,
            "sort": sort,
-            "split_out": split_out,
            "fill_value": fill_value,
            "method": method,
            "min_count": min_count,

tests/test_core.py

Lines changed: 1 addition & 3 deletions
@@ -79,7 +79,7 @@ def test_alignment_error():


 @pytest.mark.parametrize("dtype", (float, int))
-@pytest.mark.parametrize("chunk, split_out", [(False, 1), (True, 1), (True, 2), (True, 3)])
+@pytest.mark.parametrize("chunk", [False, True])
 @pytest.mark.parametrize("expected_groups", [None, [0, 1, 2], np.array([0, 1, 2])])
 @pytest.mark.parametrize(
     "func, array, by, expected",
@@ -114,7 +114,6 @@ def test_groupby_reduce(
     expected: list[float],
     expected_groups: T_ExpectedGroupsOpt,
     chunk: bool,
-    split_out: int,
     dtype: np.typing.DTypeLike,
 ) -> None:
     array = array.astype(dtype)
@@ -137,7 +136,6 @@ def test_groupby_reduce(
         func=func,
         expected_groups=expected_groups,
         fill_value=123,
-        split_out=split_out,
         engine=engine,
     )
     g_dtype = by.dtype if expected_groups is None else np.asarray(expected_groups).dtype
