@@ -736,13 +736,6 @@ def _squeeze_results(results: IntermediateDict, axis: T_Axes) -> IntermediateDic
     return newresults


-def _split_groups(array, j, slicer):
-    """Slices out chunks when split_out > 1"""
-    results = {"groups": array["groups"][..., slicer]}
-    results["intermediates"] = [v[..., slicer] for v in array["intermediates"]]
-    return results
-
-
 def _finalize_results(
     results: IntermediateDict,
     agg: Aggregation,
@@ -997,38 +990,6 @@ def _grouped_combine(
     return results


-def split_blocks(applied, split_out, expected_groups, split_name):
-    import dask.array
-    from dask.array.core import normalize_chunks
-    from dask.highlevelgraph import HighLevelGraph
-
-    chunk_tuples = tuple(itertools.product(*tuple(range(n) for n in applied.numblocks)))
-    ngroups = len(expected_groups)
-    group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))
-    idx = tuple(np.cumsum((0,) + group_chunks[0]))
-
-    # split each block into `split_out` chunks
-    dsk = {}
-    for i in chunk_tuples:
-        for j in range(split_out):
-            dsk[(split_name, *i, j)] = (
-                _split_groups,
-                (applied.name, *i),
-                j,
-                slice(idx[j], idx[j + 1]),
-            )
-
-    # now construct an array that can be passed to _tree_reduce
-    intergraph = HighLevelGraph.from_collections(split_name, dsk, dependencies=(applied,))
-    intermediate = dask.array.Array(
-        intergraph,
-        name=split_name,
-        chunks=applied.chunks + ((1,) * split_out,),
-        meta=applied._meta,
-    )
-    return intermediate, group_chunks
-
-
 def _reduce_blockwise(
     array,
     by,
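
For readers of this diff: the removed `split_blocks` helper divided the known groups into `split_out` roughly equal slices along the group axis. A minimal standalone sketch of that boundary computation, reusing the same `normalize_chunks` call with made-up sizes:

import numpy as np
from dask.array.core import normalize_chunks

# Hypothetical sizes: 10 groups split across 3 output chunks.
ngroups, split_out = 10, 3

# Same arithmetic as the removed split_blocks helper.
group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))
idx = tuple(np.cumsum((0,) + group_chunks[0]))

print(group_chunks)  # ((4, 4, 2),)
print(idx)           # (0, 4, 8, 10) -> slice boundaries along the group axis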
@@ -1169,7 +1130,6 @@ def dask_groupby_agg(
     agg: Aggregation,
     expected_groups: pd.Index | None,
     axis: T_Axes = (),
-    split_out: int = 1,
     fill_value: Any = None,
     method: T_Method = "map-reduce",
     reindex: bool = False,
@@ -1186,19 +1146,14 @@ def dask_groupby_agg(
     assert isinstance(axis, Sequence)
     assert all(ax >= 0 for ax in axis)

-    if method == "blockwise" and (split_out > 1 or not isinstance(by, np.ndarray)):
-        raise NotImplementedError
-
-    if split_out > 1 and expected_groups is None:
-        # This could be implemented using the "hash_split" strategy
-        # from dask.dataframe
+    if method == "blockwise" and not isinstance(by, np.ndarray):
         raise NotImplementedError

     inds = tuple(range(array.ndim))
     name = f"groupby_{agg.name}"
-    token = dask.base.tokenize(array, by, agg, expected_groups, axis, split_out)
+    token = dask.base.tokenize(array, by, agg, expected_groups, axis)

-    if expected_groups is None and (reindex or split_out > 1):
+    if expected_groups is None and reindex:
         expected_groups = _get_expected_groups(by, sort=sort)

     by_input = by
@@ -1229,9 +1184,7 @@ def dask_groupby_agg(
     # This allows us to discover groups at compute time, support argreductions, lower intermediate
     # memory usage (but method="cohorts" would also work to reduce memory in some cases)

-    do_simple_combine = (
-        method != "blockwise" and reindex and not _is_arg_reduction(agg) and split_out == 1
-    )
+    do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
     if method == "blockwise":
         # use the "non dask" code path, but applied blockwise
         blockwise_method = partial(
@@ -1244,14 +1197,14 @@ def dask_groupby_agg(
             func=agg.chunk,
             fill_value=agg.fill_value["intermediate"],
             dtype=agg.dtype["intermediate"],
-            reindex=reindex or (split_out > 1),
+            reindex=reindex,
         )
         if do_simple_combine:
             # Add a dummy dimension that then gets reduced over
             blockwise_method = tlz.compose(_expand_dims, blockwise_method)

     # apply reduction on chunk
-    applied = dask.array.blockwise(
+    intermediate = dask.array.blockwise(
         partial(
             blockwise_method,
             axis=axis,
@@ -1271,18 +1224,14 @@ def dask_groupby_agg(
         token=f"{name}-chunk-{token}",
     )

-    if split_out > 1:
-        intermediate, group_chunks = split_blocks(
-            applied, split_out, expected_groups, split_name=f"{name}-split-{token}"
-        )
-    else:
-        intermediate = applied
-        if expected_groups is None:
-            if is_duck_dask_array(by_input):
-                expected_groups = None
-            else:
-                expected_groups = _get_expected_groups(by_input, sort=sort)
-        group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
+    if expected_groups is None:
+        if is_duck_dask_array(by_input):
+            expected_groups = None
+        else:
+            expected_groups = _get_expected_groups(by_input, sort=sort)
+    group_chunks: tuple[tuple[Union[int, float], ...]] = (
+        (len(expected_groups),) if expected_groups is not None else (np.nan,),
+    )

     if method in ["map-reduce", "cohorts", "split-reduce"]:
         combine: Callable[..., IntermediateDict]
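
A quick illustration of the consolidated `group_chunks` logic added above: known groups yield one concrete chunk covering every group, while unknown groups (a dask `by`) yield a single NaN-length chunk. The wrapper function name here is mine, for illustration only:

import numpy as np
import pandas as pd

def sketch_group_chunks(expected_groups):
    # Mirrors the added lines: one chunk covering all groups when known,
    # else a single chunk of unknown (NaN) length.
    return ((len(expected_groups),) if expected_groups is not None else (np.nan,),)

print(sketch_group_chunks(pd.Index(["a", "b", "c"])))  # ((3,),)
print(sketch_group_chunks(None))                       # ((nan,),)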
@@ -1311,9 +1260,7 @@ def dask_groupby_agg(
         if method == "map-reduce":
             reduced = tree_reduce(
                 intermediate,
-                aggregate=partial(
-                    aggregate, expected_groups=None if split_out > 1 else expected_groups
-                ),
+                aggregate=partial(aggregate, expected_groups=expected_groups),
             )
             if is_duck_dask_array(by_input) and expected_groups is None:
                 groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
@@ -1380,7 +1327,7 @@ def dask_groupby_agg(
         raise ValueError(f"Unknown method={method}.")

     # extract results from the dict
-    output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
+    output_chunks = reduced.chunks[: -len(axis)] + group_chunks
     ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
     layer2: dict[tuple, tuple] = {}
     agg_name = f"{name}-{token}"
@@ -1392,10 +1339,7 @@ def dask_groupby_agg(
             nblocks = tuple(len(array.chunks[ax]) for ax in axis)
             inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
         else:
-            inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)
-            if split_out > 1:
-                inchunk = inchunk + (0,)
-            inchunk = inchunk + (ochunk[-1],)
+            inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)

         layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

@@ -1516,7 +1460,6 @@ def groupby_reduce(
     fill_value=None,
     dtype: np.typing.DTypeLike = None,
     min_count: int | None = None,
-    split_out: int = 1,
     method: T_Method = "map-reduce",
     engine: T_Engine = "numpy",
     reindex: bool | None = None,
@@ -1555,8 +1498,6 @@ def groupby_reduce(
         fewer than min_count non-NA values are present the result will be
         NA. Only used if skipna is set to True or defaults to True for the
         array's dtype.
-    split_out : int, optional
-        Number of chunks along group axis in output (last axis)
     method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
         Strategy for reduction of dask arrays only:
         * ``"map-reduce"``:
@@ -1750,7 +1691,7 @@ def groupby_reduce(
     if kwargs["fill_value"] is None:
         kwargs["fill_value"] = agg.fill_value[agg.name]

-    partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs)
+    partial_agg = partial(dask_groupby_agg, **kwargs)

     if method == "blockwise" and by_.ndim == 1:
         array = rechunk_for_blockwise(array, axis=-1, labels=by_)
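
After this change, callers simply drop the `split_out` keyword: the grouped axis of the result is always a single chunk. A minimal sketch of the resulting call through flox's public `groupby_reduce` entry point (array sizes and labels are made up):

import numpy as np
import dask.array as da
import flox

array = da.ones((10, 12), chunks=(5, 4))
by = np.repeat([0, 1, 2], 4)  # labels along the last axis

# No split_out keyword anymore; groups are discovered from `by`.
result, groups = flox.groupby_reduce(array, by, func="sum", method="map-reduce")
print(groups)         # [0 1 2]
print(result.chunks)  # ((5, 5), (3,)) -- one chunk along the group axis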