@@ -803,10 +803,15 @@ def _aggregate(
     keepdims,
     fill_value: Any,
     reindex: bool,
+    return_array: bool,
 ) -> FinalResultsDict:
     """Final aggregation step of tree reduction"""
     results = combine(x_chunk, agg, axis, keepdims, is_aggregate=True)
-    return _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
+    finalized = _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
+    if return_array:
+        return finalized[agg.name]
+    else:
+        return finalized


 def _expand_dims(results: IntermediateDict) -> IntermediateDict:
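The new `return_array` flag controls whether the final `aggregate` callback hands `tree_reduce` the bare array for `agg.name` or the whole `FinalResultsDict`; the dict form is only needed when group labels are still unknown and must be recovered from the same per-block results later. A minimal sketch of the two return shapes, with made-up dict contents for illustration:

import numpy as np

# Hypothetical finalized output for a "sum" aggregation over three groups.
finalized = {"groups": np.array([0, 1, 2]), "sum": np.array([10.0, 4.0, 7.0])}

# return_array=True  -> finalized["sum"], a plain array usable as a dask block.
# return_array=False -> the full dict, so both "groups" and "sum" can still be
#                       pulled out of each block downstream.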
@@ -1287,6 +1292,7 @@ def dask_groupby_agg(
     group_chunks: tuple[tuple[Union[int, float], ...]] = (
         (len(expected_groups),) if expected_groups is not None else (np.nan,),
     )
+    groups_are_unknown = is_duck_dask_array(by_input) and expected_groups is None

     if method in ["map-reduce", "cohorts"]:
         combine: Callable[..., IntermediateDict]
@@ -1316,16 +1322,32 @@ def dask_groupby_agg(
         reduced = tree_reduce(
             intermediate,
             combine=partial(combine, agg=agg),
-            aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
+            aggregate=partial(
+                aggregate,
+                expected_groups=expected_groups,
+                reindex=reindex,
+                return_array=not groups_are_unknown,
+            ),
         )
-        if is_duck_dask_array(by_input) and expected_groups is None:
+        if groups_are_unknown:
             groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
+            result = dask.array.map_blocks(
+                _extract_result,
+                reduced,
+                chunks=reduced.chunks[: -len(axis)] + group_chunks,
+                drop_axis=axis[:-1],
+                dtype=agg.dtype[agg.name],
+                key=agg.name,
+                name=f"{name}-{token}",
+            )
+
         else:
             if expected_groups is None:
                 expected_groups_ = _get_expected_groups(by_input, sort=sort)
             else:
                 expected_groups_ = expected_groups
             groups = (expected_groups_.to_numpy(),)
+            result = reduced

     elif method == "cohorts":
         chunks_cohorts = find_group_cohorts(
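When groups are unknown, the hunk above extracts the aggregated array from the per-block result dicts with `dask.array.map_blocks`, collapsing every reduced axis except the last (`drop_axis=axis[:-1]`) and forwarding `key=agg.name` as an extra kwarg to `_extract_result`. A standalone sketch of the `drop_axis` mechanics, using a plain block-wise reduction rather than flox's dict-valued blocks:

import dask.array as da

x = da.ones((4, 6), chunks=(4, 3))  # a single chunk along axis 0

# Each (4, 3) block maps to shape (3,); drop_axis=0 tells dask that the
# output metadata loses axis 0, so y is 1-D with chunks ((3, 3),).
y = x.map_blocks(lambda block: block.sum(axis=0), drop_axis=0, dtype=x.dtype)
print(y.compute())  # [4. 4. 4. 4. 4. 4.]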
@@ -1344,12 +1366,14 @@ def dask_groupby_agg(
                 tree_reduce(
                     reindexed,
                     combine=partial(combine, agg=agg, reindex=True),
-                    aggregate=partial(aggregate, expected_groups=index, reindex=True),
+                    aggregate=partial(
+                        aggregate, expected_groups=index, reindex=True, return_array=True
+                    ),
                 )
             )
             groups_.append(cohort)

-        reduced = dask.array.concatenate(reduced_, axis=-1)
+        result = dask.array.concatenate(reduced_, axis=-1)
         groups = (np.concatenate(groups_),)
         group_chunks = (tuple(len(cohort) for cohort in groups_),)

@@ -1375,21 +1399,24 @@ def dask_groupby_agg(
     for ax, chunks in zip(axis, group_chunks):
         adjust_chunks[ax] = chunks

-    result = dask.array.blockwise(
-        _extract_result,
-        inds[: -len(axis)] + (inds[-1],),
-        reduced,
-        inds,
-        adjust_chunks=adjust_chunks,
-        dtype=agg.dtype[agg.name],
-        key=agg.name,
-        name=f"{name}-{token}",
-    )
-
+    # result = dask.array.blockwise(
+    #     _extract_result,
+    #     inds[: -len(axis)] + (inds[-1],),
+    #     reduced,
+    #     inds,
+    #     adjust_chunks=adjust_chunks,
+    #     dtype=agg.dtype[agg.name],
+    #     key=agg.name,
+    #     name=f"{name}-{token}",
+    # )
     return (result, groups)


 def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
+    from dask.array.core import deepfirst
+
+    if not isinstance(result_dict, dict):
+        result_dict = deepfirst(result_dict)
     return result_dict[key]

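`deepfirst` (from `dask.array.core`) descends into nested lists and returns the first leaf; here it guards `_extract_result` against a block arriving as a (possibly nested) list of result dicts instead of a bare dict. A simplified equivalent, for illustration only:

def deepfirst(seq):
    # Walk into nested lists until the first non-list element.
    while isinstance(seq, list):
        seq = seq[0]
    return seq

deepfirst([[{"sum": 1}]])  # -> {"sum": 1}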