@@ -136,7 +136,7 @@ def _get_optimal_chunks_for_groups(chunks, labels):
     return tuple(newchunks)


-def _unique(a: np.ndarray):
+def _unique(a: np.ndarray) -> np.ndarray:
     """Much faster to use pandas unique and sort the results.
     np.unique sorts before uniquifying and is slow."""
     return np.sort(pd.unique(a.reshape(-1)))
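As a quick aside, the docstring's claim is easy to sanity-check; pd.unique returns values in order of first appearance, which is why the explicit np.sort is needed (toy input below):

import numpy as np
import pandas as pd

a = np.array([[3, 1], [3, 2]])
pd.unique(a.reshape(-1))           # array([3, 1, 2]): order of appearance
np.sort(pd.unique(a.reshape(-1)))  # array([1, 2, 3]): matches np.unique(a)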
@@ -816,8 +816,25 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
     return results


+def _find_unique_groups(x_chunk) -> np.ndarray:
+    from dask.base import flatten
+    from dask.utils import deepmap
+
+    unique_groups = _unique(np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk)))))
+    unique_groups = unique_groups[~isnull(unique_groups)]
+
+    if len(unique_groups) == 0:
+        unique_groups = np.array([np.nan])
+    return unique_groups
+
+
 def _simple_combine(
-    x_chunk, agg: Aggregation, axis: T_Axes, keepdims: bool, is_aggregate: bool = False
+    x_chunk,
+    agg: Aggregation,
+    axis: T_Axes,
+    keepdims: bool,
+    reindex: bool,
+    is_aggregate: bool = False,
 ) -> IntermediateDict:
     """
     'Simple' combination of blockwise results.
@@ -830,8 +847,19 @@ def _simple_combine(
     4. At the final aggregate step, we squeeze out DUMMY_AXIS
     """
     from dask.array.core import deepfirst
+    from dask.utils import deepmap
+
+    if not reindex:
+        # We didn't reindex at the blockwise step
+        # So now reindex before combining by reducing along DUMMY_AXIS
+        unique_groups = _find_unique_groups(x_chunk)
+        x_chunk = deepmap(
+            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
+        )
+    else:
+        unique_groups = deepfirst(x_chunk)["groups"]

-    results: IntermediateDict = {"groups": deepfirst(x_chunk)["groups"]}
+    results: IntermediateDict = {"groups": unique_groups}
     results["intermediates"] = []
     axis_ = axis[:-1] + (DUMMY_AXIS,)
     for idx, combine in enumerate(agg.combine):
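To make the role of DUMMY_AXIS concrete, here is a minimal sketch with made-up values: once every block's intermediates share the same unique_groups, combining is just an elementwise reduction along the stacked axis.

import numpy as np

# Two blockwise sums, already reindexed to groups [1, 2, 3] (fill value 0).
block0 = np.array([10.0, 5.0, 0.0])
block1 = np.array([0.0, 7.0, 2.0])
stacked = np.stack([block0, block1], axis=-1)  # stand-in for DUMMY_AXIS
stacked.sum(axis=-1)                           # array([10., 12., 2.])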
@@ -886,7 +914,6 @@ def _grouped_combine(
     sort: bool = True,
 ) -> IntermediateDict:
     """Combine intermediates step of tree reduction."""
-    from dask.base import flatten
     from dask.utils import deepmap

     if isinstance(x_chunk, dict):
@@ -897,11 +924,7 @@ def _grouped_combine(
         # when there's only a single axis of reduction, we can just concatenate later,
         # reindexing is unnecessary
         # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
-        unique_groups = _unique(np.array(tuple(flatten(deepmap(listify_groups, x_chunk)))))
-        unique_groups = unique_groups[~isnull(unique_groups)]
-        if len(unique_groups) == 0:
-            unique_groups = [np.nan]
-
+        unique_groups = _find_unique_groups(x_chunk)
         x_chunk = deepmap(
             partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
         )
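For intuition, here is a toy version of what the new helper computes, under the assumption that listify_groups simply extracts the "groups" entry from each intermediate dict (the real helpers come from dask and flox):

import numpy as np

x_chunk = [
    [{"groups": np.array([1.0, 2.0])}],
    [{"groups": np.array([2.0, np.nan])}],
]
labels = np.concatenate([d["groups"] for block in x_chunk for d in block])
np.unique(labels[~np.isnan(labels)])  # array([1., 2.]): flattened, NaN-free, sorted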
@@ -1216,7 +1239,8 @@ def dask_groupby_agg(
     # This allows us to discover groups at compute time, support argreductions, lower intermediate
     # memory usage (but method="cohorts" would also work to reduce memory in some cases)

-    do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
+    do_simple_combine = not _is_arg_reduction(agg)
+
     if method == "blockwise":
         # use the "non dask" code path, but applied blockwise
         blockwise_method = partial(
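The simplified predicate now depends only on the aggregation type. A hypothetical stand-in for _is_arg_reduction (the real check lives elsewhere in flox.core) shows which aggregations still take the grouped-combine path:

def is_arg_reduction(func_name: str) -> bool:
    # hypothetical mirror of flox's _is_arg_reduction
    return func_name in ("argmin", "argmax", "nanargmin", "nanargmax")

assert not is_arg_reduction("sum")    # -> simple combine
assert is_arg_reduction("nanargmax")  # -> grouped combine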
@@ -1268,31 +1292,32 @@ def dask_groupby_agg(
     if method in ["map-reduce", "cohorts"]:
         combine: Callable[..., IntermediateDict]
         if do_simple_combine:
-            combine = _simple_combine
+            combine = partial(_simple_combine, reindex=reindex)
+            combine_name = "simple-combine"
         else:
             combine = partial(_grouped_combine, engine=engine, sort=sort)
+            combine_name = "grouped-combine"

-        # Each chunk of `reduced` is really a dict mapping
-        # 1. reduction name to array
-        # 2. "groups" to an array of group labels
-        # Note: it does not make sense to interpret axis relative to
-        # shape of intermediate results after the blockwise call
         tree_reduce = partial(
             dask.array.reductions._tree_reduce,
-            combine=partial(combine, agg=agg),
-            name=f"{name}-reduce-{method}",
+            name=f"{name}-reduce-{method}-{combine_name}",
             dtype=array.dtype,
             axis=axis,
             keepdims=True,
             concatenate=False,
         )
-        aggregate = partial(
-            _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
-        )
+        aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value)
+
+        # Each chunk of `reduced` is really a dict mapping
+        # 1. reduction name to array
+        # 2. "groups" to an array of group labels
+        # Note: it does not make sense to interpret axis relative to
+        # shape of intermediate results after the blockwise call
         if method == "map-reduce":
             reduced = tree_reduce(
                 intermediate,
-                aggregate=partial(aggregate, expected_groups=expected_groups),
+                combine=partial(combine, agg=agg),
+                aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
             )
             if is_duck_dask_array(by_input) and expected_groups is None:
                 groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
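Per the comment above, each chunk of `reduced` is a plain dict rather than an array; an illustrative sketch, with invented reduction names and values:

import numpy as np

one_chunk = {
    "groups": np.array([[0, 1, 2]]),      # group labels seen in this chunk
    "sum": np.array([[4.0, 2.0, 7.0]]),   # hypothetical per-group partial sums
    "nanlen": np.array([[2, 1, 3]]),      # hypothetical per-group counts
}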
@@ -1310,23 +1335,17 @@ def dask_groupby_agg(
         reduced_ = []
         groups_ = []
         for blks, cohort in chunks_cohorts.items():
+            index = pd.Index(cohort)
             subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
-            if do_simple_combine:
-                # reindex so that reindex can be set to True later
-                reindexed = dask.array.map_blocks(
-                    reindex_intermediates,
-                    subset,
-                    agg=agg,
-                    unique_groups=cohort,
-                    meta=subset._meta,
-                )
-            else:
-                reindexed = subset
-
+            reindexed = dask.array.map_blocks(
+                reindex_intermediates, subset, agg=agg, unique_groups=index, meta=subset._meta
+            )
+            # now that we have reindexed, we can set reindex=True explicitly
             reduced_.append(
                 tree_reduce(
                     reindexed,
-                    aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
+                    combine=partial(combine, agg=agg, reindex=True),
+                    aggregate=partial(aggregate, expected_groups=index, reindex=True),
                 )
             )
             groups_.append(cohort)
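For context, a cohort is a collection of group labels that occur together in the same blocks; wrapping it in a pd.Index (made-up labels below) gives every block in the subset one ordered set of expected groups, which is what makes the unconditional reindex=True above safe:

import pandas as pd

cohort = (2000, 2004, 2008)  # made-up group labels
index = pd.Index(cohort)
index.get_indexer([2004])    # array([1]): stable positions for reindexing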
@@ -1382,28 +1401,24 @@ def _validate_reindex(
     if reindex is True:
         if _is_arg_reduction(func):
             raise NotImplementedError
-        if method == "blockwise":
-            raise NotImplementedError
+        if method in ["blockwise", "cohorts"]:
+            raise ValueError(
+                "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
+            )

     if reindex is None:
         if method == "blockwise" or _is_arg_reduction(func):
             reindex = False

-        elif expected_groups is not None:
-            reindex = True
-
-        elif method in ["split-reduce", "cohorts"]:
-            reindex = True
+        elif method == "cohorts":
+            reindex = False

         elif method == "map-reduce":
             if expected_groups is None and by_is_dask:
                 reindex = False
             else:
                 reindex = True

-    if method in ["split-reduce", "cohorts"] and reindex is False:
-        raise NotImplementedError
-
     assert isinstance(reindex, bool)
     return reindex

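Summarizing the new validation logic as a small decision table (grounded in the hunk above; "default" is what is inferred when the caller passes reindex=None):

  reindex passed   method / func               result
  True             arg reduction               NotImplementedError (unchanged)
  True             blockwise or cohorts        ValueError (new)
  None             blockwise or arg reduction  False (unchanged)
  None             cohorts                     False (was True)
  None             map-reduce                  False if expected_groups is None
                                               and by is dask, else True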