@@ -803,10 +803,15 @@ def _aggregate(
803803 keepdims ,
804804 fill_value : Any ,
805805 reindex : bool ,
806+ return_array : bool ,
806807) -> FinalResultsDict :
807808 """Final aggregation step of tree reduction"""
808809 results = combine (x_chunk , agg , axis , keepdims , is_aggregate = True )
809- return _finalize_results (results , agg , axis , expected_groups , fill_value , reindex )
810+ finalized = _finalize_results (results , agg , axis , expected_groups , fill_value , reindex )
811+ if return_array :
812+ return finalized [agg .name ]
813+ else :
814+ return finalized
810815
811816
812817def _expand_dims (results : IntermediateDict ) -> IntermediateDict :
@@ -1287,6 +1292,7 @@ def dask_groupby_agg(
12871292 group_chunks : tuple [tuple [Union [int , float ], ...]] = (
12881293 (len (expected_groups ),) if expected_groups is not None else (np .nan ,),
12891294 )
1295+ groups_are_unknown = is_duck_dask_array (by_input ) and expected_groups is None
12901296
12911297 if method in ["map-reduce" , "cohorts" ]:
12921298 combine : Callable [..., IntermediateDict ]
@@ -1316,16 +1322,32 @@ def dask_groupby_agg(
13161322 reduced = tree_reduce (
13171323 intermediate ,
13181324 combine = partial (combine , agg = agg ),
1319- aggregate = partial (aggregate , expected_groups = expected_groups , reindex = reindex ),
1325+ aggregate = partial (
1326+ aggregate ,
1327+ expected_groups = expected_groups ,
1328+ reindex = reindex ,
1329+ return_array = not groups_are_unknown ,
1330+ ),
13201331 )
1321- if is_duck_dask_array ( by_input ) and expected_groups is None :
1332+ if groups_are_unknown :
13221333 groups = _extract_unknown_groups (reduced , group_chunks = group_chunks , dtype = by .dtype )
1334+ result = dask .array .map_blocks (
1335+ _extract_result ,
1336+ reduced ,
1337+ chunks = reduced .chunks [: - len (axis )] + group_chunks ,
1338+ drop_axis = axis [:- 1 ],
1339+ dtype = agg .dtype [agg .name ],
1340+ key = agg .name ,
1341+ name = f"{ name } -{ token } " ,
1342+ )
1343+
13231344 else :
13241345 if expected_groups is None :
13251346 expected_groups_ = _get_expected_groups (by_input , sort = sort )
13261347 else :
13271348 expected_groups_ = expected_groups
13281349 groups = (expected_groups_ .to_numpy (),)
1350+ result = reduced
13291351
13301352 elif method == "cohorts" :
13311353 chunks_cohorts = find_group_cohorts (
@@ -1344,12 +1366,14 @@ def dask_groupby_agg(
13441366 tree_reduce (
13451367 reindexed ,
13461368 combine = partial (combine , agg = agg , reindex = True ),
1347- aggregate = partial (aggregate , expected_groups = index , reindex = True ),
1369+ aggregate = partial (
1370+ aggregate , expected_groups = index , reindex = True , return_array = True
1371+ ),
13481372 )
13491373 )
13501374 groups_ .append (cohort )
13511375
1352- reduced = dask .array .concatenate (reduced_ , axis = - 1 )
1376+ result = dask .array .concatenate (reduced_ , axis = - 1 )
13531377 groups = (np .concatenate (groups_ ),)
13541378 group_chunks = (tuple (len (cohort ) for cohort in groups_ ),)
13551379
@@ -1375,21 +1399,24 @@ def dask_groupby_agg(
13751399 for ax , chunks in zip (axis , group_chunks ):
13761400 adjust_chunks [ax ] = chunks
13771401
1378- result = dask .array .blockwise (
1379- _extract_result ,
1380- inds [: - len (axis )] + (inds [- 1 ],),
1381- reduced ,
1382- inds ,
1383- adjust_chunks = adjust_chunks ,
1384- dtype = agg .dtype [agg .name ],
1385- key = agg .name ,
1386- name = f"{ name } -{ token } " ,
1387- )
1388-
1402+ # result = dask.array.blockwise(
1403+ # _extract_result,
1404+ # inds[: -len(axis)] + (inds[-1],),
1405+ # reduced,
1406+ # inds,
1407+ # adjust_chunks=adjust_chunks,
1408+ # dtype=agg.dtype[agg.name],
1409+ # key=agg.name,
1410+ # name=f"{name}-{token}",
1411+ # )
13891412 return (result , groups )
13901413
13911414
def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
    """Return the entry stored under ``key`` in a finalized results dict.

    Parameters
    ----------
    result_dict
        Either the results ``dict`` itself, or a non-dict container that is
        unwrapped with dask's ``deepfirst`` before the lookup.
        NOTE(review): presumably a nested list of blocks, as handed over by
        the ``map_blocks`` call that uses this function -- confirm.
    key
        Key to look up in the (possibly unwrapped) results dict.

    Returns
    -------
    The value stored under ``key``.
    """
    from dask.array.core import deepfirst

    # Anything that is not already a dict is reduced to its first, innermost
    # element; that element is then expected to be the results dict.
    payload = result_dict if isinstance(result_dict, dict) else deepfirst(result_dict)
    return payload[key]
13941421
13951422
0 commit comments