@@ -1195,7 +1195,6 @@ def dask_groupby_agg(
 
     import dask.array
     from dask.array.core import slices_from_chunks
-    from dask.highlevelgraph import HighLevelGraph
 
     # I think _tree_reduce expects this
     assert isinstance(axis, Sequence)
@@ -1268,6 +1267,9 @@ def dask_groupby_agg(
             engine=engine,
             sort=sort,
         ),
+        # output indices are the same as input indices
+        # Unlike xhistogram, we don't always know what the size of the group
+        # dimension will be unless reindex=True
         inds,
         array,
         inds,
@@ -1277,7 +1279,7 @@ def dask_groupby_agg(
         dtype=array.dtype,  # this is purely for show
         meta=array._meta,
         align_arrays=False,
-        token=f"{name}-chunk-{token}",
+        name=f"{name}-chunk-{token}",
     )
 
     if expected_groups is None:
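
Aside (not part of the commit; a minimal sketch for readers unfamiliar with the pattern): the chunk step above calls dask.array.blockwise with identical input and output indices, so the per-chunk groupby reduction runs on each block independently without concatenating neighbouring blocks. The array, function, and name below are invented for illustration only.

import dask.array as da

x = da.ones((4, 6), chunks=(2, 3))

def per_block(block):
    # stand-in for the real per-chunk groupby reduction; one output block per input block
    return block * 2

y = da.blockwise(
    per_block,
    "ij",  # output indices
    x,
    "ij",  # input indices identical to the output indices, so blocks map one-to-one
    dtype=x.dtype,
    name="demo-chunk",  # sets the output name, as in the diff's name=f"{name}-chunk-{token}"
)
assert (y.compute() == 2).all()
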
@@ -1364,35 +1366,63 @@ def dask_groupby_agg(
         groups = (np.concatenate(groups_in_block),)
         ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
         group_chunks = (ngroups_per_block,)
-
     else:
         raise ValueError(f"Unknown method={method}.")
 
-    # extract results from the dict
+    out_inds = inds[: -len(axis)] + (inds[-1],)
     output_chunks = reduced.chunks[: -len(axis)] + group_chunks
+    if method == "blockwise" and len(axis) > 1:
+        # The final results are available but the blocks along axes
+        # need to be reshaped to axis=-1
+        # I don't know that this is possible with blockwise
+        # All other code paths benefit from an unmaterialized Blockwise layer
+        reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks)
+
+    # Can't use map_blocks because it forces concatenate=True along drop_axes,
+    result = dask.array.blockwise(
+        _extract_result,
+        out_inds,
+        reduced,
+        inds,
+        adjust_chunks=dict(zip(out_inds, output_chunks)),
+        dtype=agg.dtype[agg.name],
+        key=agg.name,
+        name=f"{name}-{token}",
+        concatenate=False,
+    )
+
+    return (result, groups)
+
+
+def _collapse_blocks_along_axes(reduced, axis, group_chunks):
+    import dask.array
+    from dask.highlevelgraph import HighLevelGraph
+
+    nblocks = tuple(reduced.numblocks[ax] for ax in axis)
+    output_chunks = reduced.chunks[: -len(axis)] + ((1,) * (len(axis) - 1),) + group_chunks
+
+    # extract results from the dict
     ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
     layer2: dict[tuple, tuple] = {}
-    agg_name = f"{name}-{token}"
-    for ochunk in itertools.product(*ochunks):
-        if method == "blockwise":
-            if len(axis) == 1:
-                inchunk = ochunk
-            else:
-                nblocks = tuple(len(array.chunks[ax]) for ax in axis)
-                inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
-        else:
-            inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)
+    name = f"reshape-{reduced.name}"
 
-        layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)
+    for ochunk in itertools.product(*ochunks):
+        inchunk = ochunk[: -len(axis)] + np.unravel_index(ochunk[-1], nblocks)
+        layer2[(name, *ochunk)] = (reduced.name, *inchunk)
 
-    result = dask.array.Array(
-        HighLevelGraph.from_collections(agg_name, layer2, dependencies=[reduced]),
-        agg_name,
+    return dask.array.Array(
+        HighLevelGraph.from_collections(name, layer2, dependencies=[reduced]),
+        name,
         chunks=output_chunks,
-        dtype=agg.dtype[agg.name],
+        dtype=reduced.dtype,
     )
-    return (result, groups)
+
+def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
+    from dask.array.core import deepfirst
+
+    # deepfirst should not be needed here, but sometimes we receive a list of dicts?
+    return deepfirst(result_dict)[key]
 
 
 def _validate_reindex(
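
Aside (illustration only, not from the commit): the index arithmetic in _collapse_blocks_along_axes uses np.unravel_index to fan the single collapsed block index along the last output axis back out to the tuple of input block indices being reshaped away. A toy example with invented block counts:

import numpy as np

nblocks = (2, 3)  # e.g. reduced.numblocks[ax] for two reduced axes
mapping = {
    flat: tuple(int(i) for i in np.unravel_index(flat, nblocks))
    for flat in range(2 * 3)
}
# each collapsed output block index on the last axis maps back to the pair of
# input block indices it came from
assert mapping[0] == (0, 0)
assert mapping[4] == (1, 1)
assert mapping[5] == (1, 2)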