@@ -137,7 +137,7 @@ def _get_optimal_chunks_for_groups(chunks, labels):


@memoize
-def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "cohorts"):
+def find_group_cohorts(labels, chunks, merge: bool = True):
    """
    Finds group labels that occur together, aka "cohorts"
@@ -167,9 +167,6 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
    # To do this, we must have values in memory so casting to numpy should be safe
    labels = np.asarray(labels)

-    if method == "split-reduce":
-        return list(_get_expected_groups(labels, sort=False).to_numpy().reshape(-1, 1))
-
    # Build an array with the shape of labels, but where every element is the "chunk number"
    # 1. First subset the array appropriately
    axis = range(-labels.ndim, 0)
@@ -195,7 +192,7 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
    if merge:
        # First sort by number of chunks occupied by cohort
        sorted_chunks_cohorts = dict(
-            reversed(sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0])))
+            sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
        )

        items = tuple(sorted_chunks_cohorts.items())
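
For reference, the rewritten expression produces the same size-descending order as the old `reversed(sorted(...))` form, but builds the dict from a list rather than a one-shot iterator and is stable for ties. A small illustration with made-up cohort data:

```python
chunks_cohorts = {(0,): [1], (1, 2): [4], (0, 1, 2): [2, 3]}

by_size_desc = dict(
    sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
)
# {(0, 1, 2): [2, 3], (1, 2): [4], (0,): [1]}
# Cohorts occupying the most chunks come first, so the merge loop below can
# try to fold every smaller cohort into a larger one.
```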
@@ -218,9 +215,15 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
                merged_cohorts[k1].extend(v2)
                merged_keys.append(k2)

-        return merged_cohorts.values()
+        # make sure each cohort is sorted after merging
+        sorted_merged_cohorts = {k: sorted(v) for k, v in merged_cohorts.items()}
+        # sort by first label in cohort
+        # This will help when sort=True (default)
+        # and we have to resort the dask array
+        return dict(sorted(sorted_merged_cohorts.items(), key=lambda kv: kv[1][0]))
+
    else:
-        return chunks_cohorts.values()
+        return chunks_cohorts


def rechunk_for_cohorts(
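
The key behavioral change here: `find_group_cohorts` now returns a dict mapping tuples of flat chunk indices to their (sorted) cohort of labels, instead of a bare `.values()` view; the cohorts branch of `dask_groupby_agg` below consumes both keys and values. A sketch with made-up inputs, assuming this branch of flox is installed:

```python
import numpy as np
from flox.core import find_group_cohorts

# Labels repeat with period 8 over 16 elements chunked by 4, so chunks 0 and 2
# see only labels {0, 1} while chunks 1 and 3 see only labels {2, 3}.
labels = np.tile([0, 0, 1, 1, 2, 2, 3, 3], 2)
chunks = [(4, 4, 4, 4)]

print(find_group_cohorts(labels, chunks, merge=True))
# e.g. {(0, 2): [0, 1], (1, 3): [2, 3]}
# flat chunk indices -> sorted cohort, ordered by each cohort's first label
```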
@@ -1079,6 +1082,63 @@ def _reduce_blockwise(
    return result


+def subset_to_blocks(
+    array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
+) -> DaskArray:
+    """
+    Advanced indexing of .blocks such that we always get a regular array back.
+
+    Parameters
+    ----------
+    array : dask.array
+    flatblocks : flat indices of blocks to extract
+    blkshape : shape of blocks with which to unravel flatblocks
+
+    Returns
+    -------
+    dask.array
+    """
+    if blkshape is None:
+        blkshape = array.blocks.shape
+
+    unraveled = np.unravel_index(flatblocks, blkshape)
+    normalized: list[Union[int, np.ndarray, slice]] = []
+    for ax, idx in enumerate(unraveled):
+        i = np.unique(idx).squeeze()
+        if i.ndim == 0:
+            normalized.append(i.item())
+        else:
+            if np.array_equal(i, np.arange(blkshape[ax])):
+                normalized.append(slice(None))
+            elif np.array_equal(i, np.arange(i[0], i[-1] + 1)):
+                normalized.append(slice(i[0], i[-1] + 1))
+            else:
+                normalized.append(i)
+    full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized)
+
+    # has no iterables
+    noiter = tuple(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
+    # has all iterables
+    alliter = {
+        ax: i if hasattr(i, "__len__") else slice(None) for ax, i in enumerate(full_normalized)
+    }
+
+    # apply everything but the iterables; when only iterables remain, skip the
+    # no-op .blocks call and let the loop below apply them one axis at a time
+    if all(i == slice(None) for i in noiter):
+        subset = array
+    else:
+        subset = array.blocks[noiter]
+
+    for ax, inds in alliter.items():
+        if isinstance(inds, slice):
+            continue
+        idxr = [slice(None, None)] * array.ndim
+        idxr[ax] = inds
+        subset = subset.blocks[tuple(idxr)]
+
+    return subset
+
+
def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
    import dask.array
    from dask.highlevelgraph import HighLevelGraph
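
A minimal sketch of how the new helper behaves on a 1D array (assuming `subset_to_blocks` from this diff is in scope): contiguous flat block indices are normalized to a slice, while non-contiguous ones are applied through one `.blocks` call per axis, so the result is always a regular dask array.

```python
import dask.array as da
import numpy as np

x = da.arange(16, chunks=4)  # four blocks of four elements each

# Blocks 1 and 2 are contiguous, so this normalizes to x.blocks[1:3]
y = subset_to_blocks(x, [1, 2])
assert np.array_equal(y.compute(), np.arange(4, 12))

# Blocks 0 and 3 are not contiguous; they are selected via fancy .blocks indexing
z = subset_to_blocks(x, [0, 3])
assert np.array_equal(z.compute(), np.r_[0:4, 12:16])
```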
@@ -1115,6 +1175,7 @@ def dask_groupby_agg(
    reindex: bool = False,
    engine: T_Engine = "numpy",
    sort: bool = True,
+    chunks_cohorts=None,
) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:

    import dask.array
@@ -1194,7 +1255,7 @@ def dask_groupby_agg(
        partial(
            blockwise_method,
            axis=axis,
-            expected_groups=expected_groups,
+            expected_groups=None if method in ["split-reduce", "cohorts"] else expected_groups,
            engine=engine,
            sort=sort,
        ),
@@ -1223,43 +1284,77 @@ def dask_groupby_agg(
        expected_groups = _get_expected_groups(by_input, sort=sort)
    group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)

-    if method == "map-reduce":
+    if method in ["map-reduce", "cohorts", "split-reduce"]:
        combine: Callable[..., IntermediateDict]
        if do_simple_combine:
            combine = _simple_combine
        else:
            combine = partial(_grouped_combine, engine=engine, sort=sort)

-        # reduced is really a dict mapping reduction name to array
-        # and "groups" to an array of group labels
+        # Each chunk of `reduced` is really a dict mapping
+        # 1. reduction name to array
+        # 2. "groups" to an array of group labels
        # Note: it does not make sense to interpret axis relative to
        # shape of intermediate results after the blockwise call
-        reduced = dask.array.reductions._tree_reduce(
-            intermediate,
-            aggregate=partial(
-                _aggregate,
-                combine=combine,
-                agg=agg,
-                expected_groups=None if split_out > 1 else expected_groups,
-                fill_value=fill_value,
-                reindex=reindex,
-            ),
+        tree_reduce = partial(
+            dask.array.reductions._tree_reduce,
            combine=partial(combine, agg=agg),
-            name=f"{name}-reduce",
+            name=f"{name}-reduce-{method}",
            dtype=array.dtype,
            axis=axis,
            keepdims=True,
            concatenate=False,
        )
-
-        if is_duck_dask_array(by_input) and expected_groups is None:
-            groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
-        else:
-            if expected_groups is None:
-                expected_groups_ = _get_expected_groups(by_input, sort=sort)
+        aggregate = partial(
+            _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
+        )
+        if method == "map-reduce":
+            reduced = tree_reduce(
+                intermediate,
+                aggregate=partial(
+                    aggregate, expected_groups=None if split_out > 1 else expected_groups
+                ),
+            )
+            if is_duck_dask_array(by_input) and expected_groups is None:
+                groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
            else:
-                expected_groups_ = expected_groups
-        groups = (expected_groups_.to_numpy(),)
+                if expected_groups is None:
+                    expected_groups_ = _get_expected_groups(by_input, sort=sort)
+                else:
+                    expected_groups_ = expected_groups
+                groups = (expected_groups_.to_numpy(),)
+
+        elif method in ["cohorts", "split-reduce"]:
+            chunks_cohorts = find_group_cohorts(
+                by_input, [array.chunks[ax] for ax in axis], merge=True
+            )
+            reduced_ = []
+            groups_ = []
+            for blks, cohort in chunks_cohorts.items():
+                subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
+                if do_simple_combine:
+                    # reindex so that reindex can be set to True later
+                    reindexed = dask.array.map_blocks(
+                        reindex_intermediates,
+                        subset,
+                        agg=agg,
+                        unique_groups=cohort,
+                        meta=subset._meta,
+                    )
+                else:
+                    reindexed = subset
+
+                reduced_.append(
+                    tree_reduce(
+                        reindexed,
+                        aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
+                    )
+                )
+                groups_.append(cohort)
+
+            reduced = dask.array.concatenate(reduced_, axis=-1)
+            groups = (np.concatenate(groups_),)
+            group_chunks = (tuple(len(cohort) for cohort in groups_),)

    elif method == "blockwise":
        reduced = intermediate
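
Putting the new branch together: for each cohort, only the blocks containing that cohort are subset out, their intermediates are reindexed to the cohort's labels, a tree reduction runs with `expected_groups=cohort`, and the per-cohort results are concatenated along the group axis. A small end-to-end sanity check (a sketch against this branch of flox):

```python
import dask.array as da
import numpy as np
from flox import groupby_reduce

by = np.tile([0, 0, 1, 1, 2, 2, 3, 3], 2)  # two cohorts: {0, 1} and {2, 3}
x = da.ones(by.size, chunks=4)

result, groups = groupby_reduce(x, by, func="sum", method="cohorts")
# Each label occurs four times, so every group sums to 4.0
assert np.array_equal(groups, [0, 1, 2, 3])
assert np.array_equal(result.compute(), [4.0, 4.0, 4.0, 4.0])
```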
@@ -1297,7 +1392,11 @@ def dask_groupby_agg(
            nblocks = tuple(len(array.chunks[ax]) for ax in axis)
            inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
        else:
-            inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1)
+            inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)
+            if split_out > 1:
+                inchunk = inchunk + (0,)
+            inchunk = inchunk + (ochunk[-1],)
+
        layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

    result = dask.array.Array(
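
A worked example of the new output-chunk to input-chunk mapping (values are hypothetical, not from the diff): with cohorts, `reduced` carries one chunk per cohort along its last axis, so the trailing index must always come from `ochunk[-1]` rather than only when `split_out > 1`.

```python
axis = (1, 2)    # two reduced axes
split_out = 1
ochunk = (3, 1)  # output block 3 along the kept axis, cohort chunk 1

inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)  # -> (3, 0)
if split_out > 1:
    inchunk = inchunk + (0,)
inchunk = inchunk + (ochunk[-1],)               # -> (3, 0, 1)
# The old expression would have produced (3, 0, 0) here, always pointing at
# chunk 0 of the reduced axis instead of the chunk holding cohort 1.
```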
@@ -1326,6 +1425,9 @@ def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_gro
    if method in ["split-reduce", "cohorts"] and reindex is False:
        raise NotImplementedError

+    if method in ["split-reduce", "cohorts"] and reindex is None:
+        reindex = True
+
    # TODO: Should reindex be a bool-only at this point? Would've been nice but
    # None's are relied on after this function as well.
    return reindex
@@ -1480,9 +1582,7 @@ def groupby_reduce(
          method by first rechunking using ``rechunk_for_cohorts``
          (for 1D ``by`` only).
        * ``"split-reduce"``:
-          Break out each group into its own array and then ``"map-reduce"``.
-          This is implemented by having each group be its own cohort,
-          and is identical to xarray's default strategy.
+          Same as ``"cohorts"`` and will be removed soon.
    engine : {"flox", "numpy", "numba"}, optional
        Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
        * ``"numpy"``:
@@ -1652,67 +1752,26 @@ def groupby_reduce(

    partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs)

-    if method in ["split-reduce", "cohorts"]:
-        cohorts = find_group_cohorts(
-            by_, [array.chunks[ax] for ax in axis_], merge=True, method=method
-        )
-
-        results_ = []
-        groups_ = []
-        for cohort in cohorts:
-            cohort = sorted(cohort)
-            # equivalent of xarray.DataArray.where(mask, drop=True)
-            mask = np.isin(by_, cohort)
-            indexer = [np.unique(v) for v in np.nonzero(mask)]
-            array_subset = array
-            for ax, idxr in zip(range(-by_.ndim, 0), indexer):
-                array_subset = np.take(array_subset, idxr, axis=ax)
-            numblocks = math.prod([len(array_subset.chunks[ax]) for ax in axis_])
-
-            # get final result for these groups
-            r, *g = partial_agg(
-                array_subset,
-                by_[np.ix_(*indexer)],
-                expected_groups=pd.Index(cohort),
-                # First deep copy because we might be doing blockwise,
-                # which sets agg.finalize=None, then map-reduce (GH102)
-                agg=copy.deepcopy(agg),
-                # reindex to expected_groups at the blockwise step.
-                # this approach avoids replacing non-cohort members with
-                # np.nan or some other sentinel value, and preserves dtypes
-                reindex=True,
-                # sort controls the final output order so apply that at the end
-                sort=False,
-                # if only a single block along axis, we can just work blockwise
-                # inspired by https://github.com/dask/dask/issues/8361
-                method="blockwise" if numblocks == 1 and nax == by_.ndim else "map-reduce",
-            )
-            results_.append(r)
-            groups_.append(cohort)
+    if method == "blockwise" and by_.ndim == 1:
+        array = rechunk_for_blockwise(array, axis=-1, labels=by_)

-        # concatenate results together,
-        # sort to make sure we match expected output
-        groups = (np.hstack(groups_),)
-        result = np.concatenate(results_, axis=-1)
-    else:
-        if method == "blockwise" and by_.ndim == 1:
-            array = rechunk_for_blockwise(array, axis=-1, labels=by_)
-
-        result, groups = partial_agg(
-            array,
-            by_,
-            expected_groups=None if method == "blockwise" else expected_groups,
-            agg=agg,
-            reindex=reindex,
-            method=method,
-            sort=sort,
-        )
+    result, groups = partial_agg(
+        array,
+        by_,
+        expected_groups=None if method == "blockwise" else expected_groups,
+        agg=agg,
+        reindex=reindex,
+        method=method,
+        sort=sort,
+    )

    if sort and method != "map-reduce":
        assert len(groups) == 1
        sorted_idx = np.argsort(groups[0])
-        result = result[..., sorted_idx]
-        groups = (groups[0][sorted_idx],)
+        # This optimization helps specifically with resampling
+        if not (sorted_idx[:-1] <= sorted_idx[1:]).all():
+            result = result[..., sorted_idx]
+            groups = (groups[0][sorted_idx],)

    if factorize_early:
        # nan group labels are factorized to -1, and preserved
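
The new guard skips the fancy-indexing round trip when the groups already come out sorted, which is the common case for resampling-style labels; for example:

```python
import numpy as np

groups = np.array([2, 4, 6, 8])  # already ascending, e.g. resampling bins
sorted_idx = np.argsort(groups)  # array([0, 1, 2, 3])

# Monotone index order means groups is already sorted, so the
# `result[..., sorted_idx]` reindexing above is skipped entirely.
assert (sorted_idx[:-1] <= sorted_idx[1:]).all()
```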